diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm index 9033091ecb41..953121f18946 160000 --- a/3rdparty/cutlass_fpA_intB_gemm +++ b/3rdparty/cutlass_fpA_intB_gemm @@ -1 +1 @@ -Subproject commit 9033091ecb41e7387058147e11a7087d3b363c96 +Subproject commit 953121f18946cedf88c2ccb6439944956ad495a8 diff --git a/3rdparty/nvbench/l2_cache_flush.h b/3rdparty/nvbench/l2_cache_flush.h index 6c6ccc793814..42031fc00937 100644 --- a/3rdparty/nvbench/l2_cache_flush.h +++ b/3rdparty/nvbench/l2_cache_flush.h @@ -27,11 +27,11 @@ namespace tvm { namespace runtime { -#define CUDA_CALL(func) \ - { \ - cudaError_t e = (func); \ - ICHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ - << "CUDA: " << cudaGetErrorString(e); \ +#define CUDA_CALL(func) \ + { \ + cudaError_t e = (func); \ + TVM_FFI_ICHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ + << "CUDA: " << cudaGetErrorString(e); \ } class L2Flush { diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc index e9a22628114b..b72220222072 100644 --- a/apps/cpp_rpc/main.cc +++ b/apps/cpp_rpc/main.cc @@ -143,7 +143,7 @@ string GetCmdOption(int argc, char* argv[], string option, bool key = false) { return cmd; } // We assume "=" is the end of option. - ICHECK_EQ(*option.rbegin(), '='); + TVM_FFI_ICHECK_EQ(*option.rbegin(), '='); cmd = arg.substr(arg.find('=') + 1); return cmd; } diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index 0ff063da41c2..a9d463a1aa84 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -157,7 +157,7 @@ RPCEnv::RPCEnv(const std::string& wd) { std::string bin; std::ifstream fs(file_name, std::ios::in | std::ios::binary); - ICHECK(!fs.fail()) << "Cannot open " << file_name; + TVM_FFI_ICHECK(!fs.fail()) << "Cannot open " << file_name; fs.seekg(0, std::ios::end); size_t size = static_cast(fs.tellg()); fs.seekg(0, std::ios::beg); @@ -199,7 +199,7 @@ std::vector ListDir(const std::string& dirname) { DIR* dp = opendir(dirname.c_str()); if (dp == nullptr) { int errsv = errno; - LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv); + TVM_FFI_THROW(InternalError) << "ListDir " << dirname << " error: " << strerror(errsv); } dirent* d; while ((d = readdir(dp)) != nullptr) { @@ -220,7 +220,7 @@ std::vector ListDir(const std::string& dirname) { HANDLE handle = FindFirstFileA(pattern.c_str(), &fd); if (handle == INVALID_HANDLE_VALUE) { const int errsv = GetLastError(); - LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv); + TVM_FFI_THROW(InternalError) << "ListDir " << dirname << " error: " << strerror(errsv); } do { std::string filename = fd.cFileName; @@ -235,7 +235,7 @@ std::vector ListDir(const std::string& dirname) { } while (FindNextFileA(handle, &fd)); FindClose(handle); #else - LOG(FATAL) << "Operating system not supported"; + TVM_FFI_THROW(InternalError) << "Operating system not supported"; #endif return vec; } @@ -260,7 +260,7 @@ void LinuxShared(const std::string output, const std::vector& files std::string err_msg; auto executed_status = support::Execute(cmd, &err_msg); if (executed_status) { - LOG(FATAL) << err_msg; + TVM_FFI_THROW(InternalError) << err_msg; } } #endif @@ -285,7 +285,7 @@ void WindowsShared(const std::string& output, const std::vector& fi std::string err_msg; const auto executed_status = support::Execute(cmd, &err_msg); if (executed_status) { - LOG(FATAL) << err_msg; + TVM_FFI_THROW(InternalError) << err_msg; } } #endif @@ -301,7 +301,7 @@ void CreateShared(const std::string& output, const std::vector& fil #elif defined(_WIN32) WindowsShared(output, files); #else - LOG(FATAL) << "Operating system not supported"; + TVM_FFI_THROW(InternalError) << "Operating system not supported"; #endif } @@ -323,7 +323,7 @@ std::string BuildSharedLibrary(std::string file) { std::string err_msg; const int executed_status = support::Execute(cmd, &err_msg); if (executed_status) { - LOG(FATAL) << err_msg; + TVM_FFI_THROW(InternalError) << err_msg; } CreateShared(file_name, ListDir(tmp_dir)); CleanDir(tmp_dir); diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index 8a948c9efad5..9425e734c309 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -262,15 +262,15 @@ class RPCServer { support::TCPSocket conn = listen_sock_.Accept(addr); int code = kRPCMagic; - ICHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); + TVM_FFI_ICHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); if (code != kRPCMagic) { conn.Close(); - LOG(FATAL) << "Client connected is not TVM RPC server"; + TVM_FFI_THROW(InternalError) << "Client connected is not TVM RPC server"; continue; } int keylen = 0; - ICHECK_EQ(conn.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen)); + TVM_FFI_ICHECK_EQ(conn.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen)); const char* CLIENT_HEADER = "client:"; const char* SERVER_HEADER = "server:"; @@ -282,10 +282,10 @@ class RPCServer { continue; } - ICHECK_NE(keylen, 0); + TVM_FFI_ICHECK_NE(keylen, 0); std::string remote_key; remote_key.resize(keylen); - ICHECK_EQ(conn.RecvAll(&remote_key[0], keylen), keylen); + TVM_FFI_ICHECK_EQ(conn.RecvAll(&remote_key[0], keylen), keylen); std::stringstream ssin(remote_key); std::string arg0; @@ -297,16 +297,16 @@ class RPCServer { if (arg0 != expect_header) { code = kRPCMismatch; - ICHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); + TVM_FFI_ICHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); conn.Close(); LOG(WARNING) << "Mismatch key from" << addr->AsString(); continue; } else { code = kRPCSuccess; - ICHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); + TVM_FFI_ICHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); keylen = int(server_key.length()); - ICHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); - ICHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen); + TVM_FFI_ICHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); + TVM_FFI_ICHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen); LOG(INFO) << "Connection success " << addr->AsString(); #ifndef __ANDROID__ ssin >> *opts; @@ -343,7 +343,7 @@ class RPCServer { size_t pos = opts.rfind(option); if (pos != std::string::npos) { const std::string cmd = opts.substr(pos + option.size()); - ICHECK(support::IsNumber(cmd)) << "Timeout is not valid"; + TVM_FFI_ICHECK(support::IsNumber(cmd)) << "Timeout is not valid"; return std::stoi(cmd); } return 0; diff --git a/apps/cpp_rpc/rpc_tracker_client.h b/apps/cpp_rpc/rpc_tracker_client.h index 329d2717de18..03a8dcf3a5fc 100644 --- a/apps/cpp_rpc/rpc_tracker_client.h +++ b/apps/cpp_rpc/rpc_tracker_client.h @@ -84,9 +84,9 @@ class TrackerClient { tracker_sock_ = ConnectWithRetry(); int code = kRPCTrackerMagic; - ICHECK_EQ(tracker_sock_.SendAll(&code, sizeof(code)), sizeof(code)); - ICHECK_EQ(tracker_sock_.RecvAll(&code, sizeof(code)), sizeof(code)); - ICHECK_EQ(code, kRPCTrackerMagic) << tracker_addr_.c_str() << " is not RPC Tracker"; + TVM_FFI_ICHECK_EQ(tracker_sock_.SendAll(&code, sizeof(code)), sizeof(code)); + TVM_FFI_ICHECK_EQ(tracker_sock_.RecvAll(&code, sizeof(code)), sizeof(code)); + TVM_FFI_ICHECK_EQ(code, kRPCTrackerMagic) << tracker_addr_.c_str() << " is not RPC Tracker"; std::ostringstream ss; ss << "[" << static_cast(TrackerCode::kUpdateInfo) << ", {\"key\": \"server:" << key_ @@ -95,7 +95,7 @@ class TrackerClient { // Receive status and validate std::string remote_status = tracker_sock_.RecvBytes(); - ICHECK_EQ(std::stoi(remote_status), static_cast(TrackerCode::kSuccess)); + TVM_FFI_ICHECK_EQ(std::stoi(remote_status), static_cast(TrackerCode::kSuccess)); } } /*! @@ -124,7 +124,7 @@ class TrackerClient { // Receive status and validate std::string remote_status = tracker_sock_.RecvBytes(); - ICHECK_EQ(std::stoi(remote_status), static_cast(TrackerCode::kSuccess)); + TVM_FFI_ICHECK_EQ(std::stoi(remote_status), static_cast(TrackerCode::kSuccess)); } else { *matchkey = key_; } @@ -174,7 +174,7 @@ class TrackerClient { tracker_sock_.SendBytes(ss.str()); std::string remote_status = tracker_sock_.RecvBytes(); - ICHECK_EQ(std::stoi(remote_status), static_cast(TrackerCode::kSuccess)); + TVM_FFI_ICHECK_EQ(std::stoi(remote_status), static_cast(TrackerCode::kSuccess)); unmatch_period_count = 0; } continue; @@ -206,7 +206,7 @@ class TrackerClient { auto period = (std::chrono::duration_cast( std::chrono::system_clock::now() - tbegin)) .count(); - ICHECK(period < timeout) << "Failed to connect to server" << addr.AsString(); + TVM_FFI_ICHECK(period < timeout) << "Failed to connect to server" << addr.AsString(); LOG(WARNING) << "Cannot connect to tracker " << addr.AsString() << " retry in " << retry_period << " seconds."; std::this_thread::sleep_for(std::chrono::seconds(retry_period)); diff --git a/apps/cpp_rpc/win32_process.cc b/apps/cpp_rpc/win32_process.cc index cc4c45d81a1d..888b751e0fe1 100644 --- a/apps/cpp_rpc/win32_process.cc +++ b/apps/cpp_rpc/win32_process.cc @@ -93,19 +93,19 @@ SOCKET GetSocket(const std::string& mmap_path) { UniqueHandle parent_file_mapping_event; if ((parent_file_mapping_event = MakeUniqueHandle( OpenEventA(SYNCHRONIZE, false, parent_event_name.c_str()))) == nullptr) { - LOG(FATAL) << "OpenEvent() failed: " << GetLastError(); + TVM_FFI_THROW(InternalError) << "OpenEvent() failed: " << GetLastError(); } UniqueHandle child_file_mapping_event; if ((child_file_mapping_event = MakeUniqueHandle( OpenEventA(EVENT_MODIFY_STATE, false, child_event_name.c_str()))) == nullptr) { - LOG(FATAL) << "OpenEvent() failed: " << GetLastError(); + TVM_FFI_THROW(InternalError) << "OpenEvent() failed: " << GetLastError(); } // Wait for the parent to set the event, notifying WSAPROTOCOL_INFO is ready to be read if (WaitForSingleObject(parent_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) { - LOG(FATAL) << "WaitForSingleObject() failed: " << GetLastError(); + TVM_FFI_THROW(InternalError) << "WaitForSingleObject() failed: " << GetLastError(); } const UniqueHandle file_map = @@ -129,7 +129,7 @@ SOCKET GetSocket(const std::string& mmap_path) { // Let the parent know we are finished duplicating the socket SetEvent(child_file_mapping_event.get()); } else { - LOG(FATAL) << "MapViewOfFile() failed: " << GetLastError(); + TVM_FFI_THROW(InternalError) << "MapViewOfFile() failed: " << GetLastError(); } return sock_duplicated; @@ -158,14 +158,14 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { UniqueHandle parent_file_mapping_event; if ((parent_file_mapping_event = MakeUniqueHandle( CreateEventA(nullptr, true, false, parent_event_name.c_str()))) == nullptr) { - LOG(FATAL) << "CreateEvent for parent file mapping failed"; + TVM_FFI_THROW(InternalError) << "CreateEvent for parent file mapping failed"; } UniqueHandle child_file_mapping_event; // An event to let the parent know the socket info was read from the mmap file if ((child_file_mapping_event = MakeUniqueHandle( CreateEventA(nullptr, true, false, child_event_name.c_str()))) == nullptr) { - LOG(FATAL) << "CreateEvent for child file mapping failed"; + TVM_FFI_THROW(InternalError) << "CreateEvent for child file mapping failed"; } char current_executable[MAX_PATH]; @@ -191,7 +191,7 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { WSAPROTOCOL_INFO protocol_info; // Get info needed to duplicate the socket if (WSADuplicateSocket(fd, child_process_info.dwProcessId, &protocol_info) == SOCKET_ERROR) { - LOG(FATAL) << "WSADuplicateSocket(): failed. Error =" << WSAGetLastError(); + TVM_FFI_THROW(InternalError) << "WSADuplicateSocket(): failed. Error =" << WSAGetLastError(); } // Create a mmap file to store the info needed for duplicating the SOCKET in the child proc @@ -203,7 +203,7 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { } if (GetLastError() == ERROR_ALREADY_EXISTS) { - LOG(FATAL) << "CreateFileMapping(): mapping file already exists"; + TVM_FFI_THROW(InternalError) << "CreateFileMapping(): mapping file already exists"; } else { void* map_view = MapViewOfFile(file_map.get(), FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, 0); @@ -218,12 +218,13 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { if (WaitForSingleObject(child_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) { TerminateProcess(child_process_handle.get(), 0); - LOG(FATAL) << "WaitForSingleObject for child file mapping timed out. Terminating child " - "process."; + TVM_FFI_THROW(InternalError) + << "WaitForSingleObject for child file mapping timed out. Terminating child " + "process."; } } else { TerminateProcess(child_process_handle.get(), 0); - LOG(FATAL) << "MapViewOfFile() failed: " << GetLastError(); + TVM_FFI_THROW(InternalError) << "MapViewOfFile() failed: " << GetLastError(); } } @@ -254,7 +255,7 @@ void ChildProcSocketHandler(const std::string& mmap_path) { if ((socket = GetSocket(mmap_path)) != INVALID_SOCKET) { tvm::runtime::ServerLoopFromChild(socket); } else { - LOG(FATAL) << "GetSocket() failed"; + TVM_FFI_THROW(InternalError) << "GetSocket() failed"; } } } // namespace runtime diff --git a/apps/hexagon_launcher/launcher_core.cc b/apps/hexagon_launcher/launcher_core.cc index 0099f7fc79e5..44fc48c92701 100644 --- a/apps/hexagon_launcher/launcher_core.cc +++ b/apps/hexagon_launcher/launcher_core.cc @@ -182,7 +182,7 @@ tvm::runtime::Module load_module(const std::string& file_name) { static const tvm::ffi::Function loader = get_runtime_func("ffi.Module.load_from_file.hexagon"); tvm::ffi::Any rv = loader(file_name); if (rv.type_code() == kTVMModuleHandle) { - ICHECK_EQ(rv.type_code(), kTVMModuleHandle) + TVM_FFI_ICHECK_EQ(rv.type_code(), kTVMModuleHandle) << __func__ << ": loaded " << file_name << ", but did not get module handle"; return rv.operator tvm::runtime::Module(); } diff --git a/apps/hexagon_launcher/launcher_util.cc b/apps/hexagon_launcher/launcher_util.cc index 5524c2f0f338..ddbafb3c84a9 100644 --- a/apps/hexagon_launcher/launcher_util.cc +++ b/apps/hexagon_launcher/launcher_util.cc @@ -42,7 +42,7 @@ size_t get_file_size(std::ifstream&& in_file) { std::string load_text_file(const std::string& file_name) { constexpr size_t block_size = 1024 * 1024; // 1MB std::ifstream in_file(file_name); - ICHECK(in_file.is_open()) << "cannot open file " << file_name; + TVM_FFI_ICHECK(in_file.is_open()) << "cannot open file " << file_name; size_t file_size = get_file_size(in_file); std::string buffer(file_size + 1, 0); @@ -52,7 +52,7 @@ std::string load_text_file(const std::string& file_name) { void* load_binary_file(const std::string& file_name, void* buffer, size_t buffer_size) { std::ifstream in_file(file_name); - ICHECK(in_file.is_open()) << "cannot open file " << file_name; + TVM_FFI_ICHECK(in_file.is_open()) << "cannot open file " << file_name; size_t file_size = get_file_size(in_file); in_file.read(reinterpret_cast(buffer), @@ -62,7 +62,7 @@ void* load_binary_file(const std::string& file_name, void* buffer, size_t buffer void write_binary_file(const std::string& file_name, void* buffer, size_t buffer_size) { std::ofstream out_file(file_name); - ICHECK(out_file.is_open()) << "cannot open file " << file_name; + TVM_FFI_ICHECK(out_file.is_open()) << "cannot open file " << file_name; out_file.write(reinterpret_cast(buffer), buffer_size); } diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 099643d0a0bb..85d814f5b223 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -537,10 +537,10 @@ class TransitiveComparisonAnalyzer { * arith::Analyzer analyzer; * { * With scope(&analyzer, x % 3 == 0); - * ICHECK_EQ(analyzer.modular_set(x)->coeff, 3); + * TVM_FFI_ICHECK_EQ(analyzer.modular_set(x)->coeff, 3); * } * // constraint no longer in effect. - * ICHECK_NE(analyzer.modular_set(x)->coeff, 3); + * TVM_FFI_ICHECK_NE(analyzer.modular_set(x)->coeff, 3); * * \endcode */ diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 29371d206957..3da7f8d1c18e 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -384,7 +384,8 @@ template class AttrsNodeReflAdapter : public BaseAttrsNode { public: void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final { - LOG(FATAL) << "`" << DerivedType::_type_key << "` uses new reflection mechanism for init"; + TVM_FFI_THROW(InternalError) << "`" << DerivedType::_type_key + << "` uses new reflection mechanism for init"; } private: diff --git a/include/tvm/ir/diagnostic.h b/include/tvm/ir/diagnostic.h index 24553de6c408..2891ab0e7994 100644 --- a/include/tvm/ir/diagnostic.h +++ b/include/tvm/ir/diagnostic.h @@ -65,13 +65,16 @@ class DiagnosticNode : public Object { ObjectRef loc; /*! \brief The diagnostic message. */ ffi::String message; + /*! \brief The error kind when the diagnostic is used as an error (e.g. "TypeError"). */ + ffi::String error_kind{"InternalError"}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("level", &DiagnosticNode::level) .def_ro("span", &DiagnosticNode::span) - .def_ro("message", &DiagnosticNode::message); + .def_ro("message", &DiagnosticNode::message) + .def_ro("error_kind", &DiagnosticNode::error_kind); } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; @@ -81,6 +84,8 @@ class DiagnosticNode : public Object { class Diagnostic : public ObjectRef { public: TVM_DLL Diagnostic(DiagnosticLevel level, Span span, const std::string& message); + TVM_DLL Diagnostic(DiagnosticLevel level, Span span, const std::string& message, + const std::string& error_kind); static DiagnosticBuilder Bug(Span span); static DiagnosticBuilder Error(Span span); @@ -99,6 +104,10 @@ class Diagnostic : public ObjectRef { static DiagnosticBuilder Warning(const Object* loc); static DiagnosticBuilder Note(const Object* loc); static DiagnosticBuilder Help(const Object* loc); + // variants with error kind + static DiagnosticBuilder Error(std::string error_kind, Span span); + static DiagnosticBuilder Error(std::string error_kind, ObjectRef loc); + static DiagnosticBuilder Error(std::string error_kind, const Object* loc); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Diagnostic, ObjectRef, DiagnosticNode); }; @@ -122,6 +131,9 @@ class DiagnosticBuilder { */ ObjectRef loc; + /*! \brief The error kind (e.g. "TypeError", "ValueError"). */ + std::string error_kind{"InternalError"}; + template DiagnosticBuilder& operator<<(const T& val) { // NOLINT(*) stream_ << val; @@ -131,13 +143,24 @@ class DiagnosticBuilder { DiagnosticBuilder() : level(DiagnosticLevel::kError), source_name(), span(Span()) {} DiagnosticBuilder(const DiagnosticBuilder& builder) - : level(builder.level), source_name(builder.source_name), span(builder.span) {} + : level(builder.level), + source_name(builder.source_name), + span(builder.span), + error_kind(builder.error_kind) {} DiagnosticBuilder(DiagnosticLevel level, Span span) : level(level), span(span) {} DiagnosticBuilder(DiagnosticLevel level, ObjectRef loc) : level(level), loc(loc) {} - operator Diagnostic() { return Diagnostic(this->level, this->span, this->stream_.str()); } + /*! \brief Set the error kind for this diagnostic. */ + DiagnosticBuilder& WithErrorKind(std::string kind) { + error_kind = std::move(kind); + return *this; + } + + operator Diagnostic() { + return Diagnostic(this->level, this->span, this->stream_.str(), this->error_kind); + } private: std::stringstream stream_; @@ -178,7 +201,7 @@ class DiagnosticRenderer : public ObjectRef { void Render(const DiagnosticContext& ctx); DiagnosticRendererNode* operator->() { - ICHECK(get() != nullptr); + TVM_FFI_ICHECK(get() != nullptr); return static_cast(get_mutable()); } @@ -231,7 +254,7 @@ class DiagnosticContext : public ObjectRef { void Render(); DiagnosticContextNode* operator->() { - ICHECK(get() != nullptr); + TVM_FFI_ICHECK(get() != nullptr); return static_cast(get_mutable()); } diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index c0735b7cd69f..264198333e6f 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -83,7 +83,7 @@ class EnvFunc : public ObjectRef { template ffi::Any operator()(Args&&... args) const { const EnvFuncNode* n = operator->(); - ICHECK(n != nullptr); + TVM_FFI_ICHECK(n != nullptr); return n->func(std::forward(args)...); } /*! @@ -141,7 +141,7 @@ class TypedEnvFunc : public ObjectRef { */ R operator()(Args... args) const { const EnvFuncNode* n = operator->(); - ICHECK(n != nullptr); + TVM_FFI_ICHECK(n != nullptr); if constexpr (std::is_same_v) { n->func(std::forward(args)...); } else { diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index fd2e0e6a5145..faf2c18c1cac 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -638,7 +638,7 @@ class Integer : public IntImm { * \brief convert to int64_t */ int64_t IntValue() const { - ICHECK(data_ != nullptr) << " Trying to reference a null Integer"; + TVM_FFI_ICHECK(data_ != nullptr) << " Trying to reference a null Integer"; return (*this)->value; } // comparators diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 17369dbab665..becd19ed70be 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -280,7 +280,7 @@ class IRModule : public ObjectRef { /*! \return mutable pointers to the node. */ IRModuleNode* operator->() const { auto* ptr = get_mutable(); - ICHECK(ptr != nullptr); + TVM_FFI_ICHECK(ptr != nullptr); return static_cast(ptr); } diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h index d3139ea2c821..50dd2567d2ca 100644 --- a/include/tvm/ir/name_supply.h +++ b/include/tvm/ir/name_supply.h @@ -151,7 +151,8 @@ class NameSupply : public ObjectRef { // name = {O = others}{D = consecutive digits} // let O -> prefix; std::string prefix = name.substr(0, idx_last_first_num); - ICHECK(prefix.size() > 0 && std::isalpha(prefix[0])) << "Invalid variable name: " << name; + TVM_FFI_ICHECK(prefix.size() > 0 && std::isalpha(prefix[0])) + << "Invalid variable name: " << name; if (0 == name_map.count(prefix)) name_map[prefix] = 0; if (idx_last_first_num < name.size()) { // has some digits. // let D's nearest natural number -> idx; diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 2153ce4190d4..9171a9e6d2df 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -364,7 +364,7 @@ inline OpRegEntry& OpRegEntry::set_support_level(int32_t n) { // NOLINT(*) template inline OpRegEntry& OpRegEntry::set_attr( // NOLINT(*) const std::string& attr_name, const ValueType& value, int plevel) { - ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; + TVM_FFI_ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; UpdateAttr(attr_name, Any(value), plevel); return *this; } @@ -373,7 +373,7 @@ inline OpRegEntry& OpRegEntry::set_attr( // NOLINT(*) template inline ValueType OpAttrMap::get(const RelaxExpr& expr, ValueType def_value) const { - ICHECK(expr.defined()); + TVM_FFI_ICHECK(expr.defined()); if (const OpNode* op = expr.as()) { return this->map_.get(ffi::GetRef(op), def_value); } else { diff --git a/include/tvm/ir/source_map.h b/include/tvm/ir/source_map.h index c94fb6b0a120..60a30ffe1709 100644 --- a/include/tvm/ir/source_map.h +++ b/include/tvm/ir/source_map.h @@ -214,7 +214,7 @@ class SourceMap : public ObjectRef { void Add(const Source& source); SourceMapObj* operator->() { - ICHECK(get() != nullptr); + TVM_FFI_ICHECK(get() != nullptr); return static_cast(get_mutable()); } diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 77d90a0e9558..97c98ccbf4d6 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -166,7 +166,7 @@ class PassContext : public ObjectRef { * \return const access pointer. */ const PassContextNode* operator->() const { - ICHECK(get() != nullptr); + TVM_FFI_ICHECK(get() != nullptr); return static_cast(get()); } /*! @@ -174,7 +174,7 @@ class PassContext : public ObjectRef { * \return mutable access pointer. */ PassContextNode* operator->() { - ICHECK(get() != nullptr); + TVM_FFI_ICHECK(get() != nullptr); return static_cast(get_mutable()); } diff --git a/include/tvm/ir/type_functor.h b/include/tvm/ir/type_functor.h index b2878519c424..ae9c8f1fb080 100644 --- a/include/tvm/ir/type_functor.h +++ b/include/tvm/ir/type_functor.h @@ -70,7 +70,7 @@ class TypeFunctor { * \return The result of the call */ virtual R VisitType(const Type& n, Args... args) { - ICHECK(n.defined()); + TVM_FFI_ICHECK(n.defined()); static FType vtable = InitVTable(); return vtable(n, this, std::forward(args)...); } @@ -80,7 +80,7 @@ class TypeFunctor { virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const PointerTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitTypeDefault_(const Object* op, Args...) { - LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Do not have a default for " << op->GetTypeKey(); throw; // unreachable, written to stop compiler warning } diff --git a/include/tvm/node/attr_registry_map.h b/include/tvm/node/attr_registry_map.h index e273fa8f5fe1..9b689304ca69 100644 --- a/include/tvm/node/attr_registry_map.h +++ b/include/tvm/node/attr_registry_map.h @@ -56,9 +56,9 @@ class AttrRegistryMapContainerMap { * \return the const reference to the content value. */ const ffi::Any& operator[](const KeyType& key) const { - ICHECK(key.defined()); + TVM_FFI_ICHECK(key.defined()); const uint32_t idx = key->AttrRegistryIndex(); - ICHECK(idx < data_.size() && data_[idx].second != 0) + TVM_FFI_ICHECK(idx < data_.size() && data_[idx].second != 0) << "Attribute " << attr_name_ << " has not been registered for " << key->name; return data_[idx].first; } @@ -71,7 +71,7 @@ class AttrRegistryMapContainerMap { */ template ValueType get(const KeyType& key, ValueType def_value) const { - ICHECK(key.defined()); + TVM_FFI_ICHECK(key.defined()); const uint32_t idx = key->AttrRegistryIndex(); if (idx < data_.size() && data_[idx].second != 0) { if constexpr (std::is_same_v) { diff --git a/include/tvm/node/functor.h b/include/tvm/node/functor.h index 82ea37566eb5..b9a468153226 100644 --- a/include/tvm/node/functor.h +++ b/include/tvm/node/functor.h @@ -97,8 +97,8 @@ class NodeFunctor { * \return The result. */ R operator()(const ObjectRef& n, Args... args) const { - ICHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type " - << n->GetTypeKey(); + TVM_FFI_ICHECK(can_dispatch(n)) + << "NodeFunctor calls un-registered function on type " << n->GetTypeKey(); return (*func_[n->type_index() - begin_type_index_])(n, std::forward(args)...); } /*! @@ -113,8 +113,9 @@ class NodeFunctor { if (func_.size() <= tindex) { func_.resize(tindex + 1, nullptr); } - ICHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set"; - ICHECK_EQ(begin_type_index_, 0) << " Cannot call set_dispatch after calling Finalize"; + TVM_FFI_ICHECK(func_[tindex] == nullptr) + << "Dispatch for " << TNode::_type_key << " is already set"; + TVM_FFI_ICHECK_EQ(begin_type_index_, 0) << " Cannot call set_dispatch after calling Finalize"; func_[tindex] = f; return *this; } @@ -127,8 +128,8 @@ class NodeFunctor { template TSelf& clear_dispatch() { // NOLINT(*) uint32_t tindex = TNode::RuntimeTypeIndex(); - ICHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range"; - ICHECK_EQ(begin_type_index_, 0) << " Cannot call clear_dispatch after calling Finalize"; + TVM_FFI_ICHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range"; + TVM_FFI_ICHECK_EQ(begin_type_index_, 0) << " Cannot call clear_dispatch after calling Finalize"; func_[tindex] = nullptr; return *this; } @@ -138,7 +139,7 @@ class NodeFunctor { * and optimize the space of the func table so it is more compact */ void Finalize() { - ICHECK_EQ(begin_type_index_, 0) << "Can only call Finalize once"; + TVM_FFI_ICHECK_EQ(begin_type_index_, 0) << "Can only call Finalize once"; while (begin_type_index_ < func_.size() && func_[begin_type_index_] == nullptr) { ++begin_type_index_; } diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h index f3e0edab6e07..1c1f1ec38099 100644 --- a/include/tvm/node/repr_printer.h +++ b/include/tvm/node/repr_printer.h @@ -118,7 +118,7 @@ inline std::ostream& operator<<(std::ostream& os, const AccessStep& step) { return os; } default: { - LOG(FATAL) << "Unknown access step kind: " << static_cast(step->kind); + TVM_FFI_THROW(InternalError) << "Unknown access step kind: " << static_cast(step->kind); } } return os; diff --git a/include/tvm/relax/binding_rewrite.h b/include/tvm/relax/binding_rewrite.h index 90d5b1540ee0..9b734fb7f5c8 100644 --- a/include/tvm/relax/binding_rewrite.h +++ b/include/tvm/relax/binding_rewrite.h @@ -102,7 +102,7 @@ class DataflowBlockRewrite : public ObjectRef { * \return mutable access pointer. */ DataflowBlockRewriteNode* operator->() { - ICHECK(get() != nullptr); + TVM_FFI_ICHECK(get() != nullptr); return static_cast(get_mutable()); } diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index 1925d5ae148d..492b53bf596b 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -279,12 +279,12 @@ class PatternContext : public ObjectRef { TVM_DLL explicit PatternContext(bool incremental = false); const PatternContextNode* operator->() const { - ICHECK(get() != nullptr); + TVM_FFI_ICHECK(get() != nullptr); return static_cast(get()); } PatternContextNode* operator->() { - ICHECK(get() != nullptr); + TVM_FFI_ICHECK(get() != nullptr); return static_cast(get_mutable()); } @@ -303,7 +303,7 @@ class PatternContext : public ObjectRef { pairs.emplace_back(consumer, std::vector{cons}); } else { auto& vec = it->second; - ICHECK(std::find(vec.cbegin(), vec.cend(), cons) == vec.cend()) + TVM_FFI_ICHECK(std::find(vec.cbegin(), vec.cend(), cons) == vec.cend()) << "Constraint already exists"; vec.push_back(cons); } diff --git a/include/tvm/relax/dataflow_pattern_functor.h b/include/tvm/relax/dataflow_pattern_functor.h index c12ab0326df4..89098546391a 100644 --- a/include/tvm/relax/dataflow_pattern_functor.h +++ b/include/tvm/relax/dataflow_pattern_functor.h @@ -76,7 +76,7 @@ class DFPatternFunctor { * \return The result of the call */ virtual R VisitDFPattern(const DFPattern& n, Args... args) { - ICHECK(n.defined()); + TVM_FFI_ICHECK(n.defined()); static FType vtable = InitVTable(); return vtable(n, this, std::forward(args)...); } @@ -109,7 +109,7 @@ class DFPatternFunctor { Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPatternDefault_(const Object* op, Args...) { - LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Do not have a default for " << op->GetTypeKey(); throw; } diff --git a/include/tvm/relax/distributed/axis_group_graph.h b/include/tvm/relax/distributed/axis_group_graph.h index 26a6ab228c52..6ea322938f06 100644 --- a/include/tvm/relax/distributed/axis_group_graph.h +++ b/include/tvm/relax/distributed/axis_group_graph.h @@ -213,7 +213,7 @@ struct Axis { Axis(const ExprNode* tensor, int dim, int tuple_index = 0) : tensor(tensor), dim(dim), tuple_index(tuple_index) { - ICHECK(tensor->IsInstance() || tensor->IsInstance()); + TVM_FFI_ICHECK(tensor->IsInstance() || tensor->IsInstance()); } bool operator==(const Axis& other) const { @@ -284,7 +284,7 @@ class AxisGroupGraph { case EdgeType::kSimbling: return EdgeType::kSimbling; } - LOG(FATAL) << "Unreachable code"; + TVM_FFI_THROW(InternalError) << "Unreachable code"; throw; } @@ -297,7 +297,7 @@ class AxisGroupGraph { case EdgeType::kSimbling: return 1; } - LOG(FATAL) << "Unreachable code"; + TVM_FFI_THROW(InternalError) << "Unreachable code"; throw; } @@ -439,8 +439,9 @@ class AxisGroupGraph { it++; } } - ICHECK(specs.size() == 1) << "multiple possible sharding for axis: (" - << ffi::GetRef(axis.tensor) << ", " << axis.dim << ")"; + TVM_FFI_ICHECK(specs.size() == 1) + << "multiple possible sharding for axis: (" << ffi::GetRef(axis.tensor) << ", " + << axis.dim << ")"; } } diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index afacb81e4072..9079be8f329b 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -126,8 +126,9 @@ class ExprFunctor { * \return The result of the call */ virtual R VisitExpr(const Expr& n, Args... args) { - ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may " - "have generated invalid data."; + TVM_FFI_ICHECK(n.defined()) + << "Found null pointer node while traversing AST. The previous pass may " + "have generated invalid data."; static FType vtable = InitVTable(); return vtable(n, this, std::forward(args)...); } @@ -151,7 +152,7 @@ class ExprFunctor { virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const DataTypeImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Object* op, Args...) { - LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Do not have a default for " << op->GetTypeKey(); throw; } diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h index 77f001630f75..20495e00102b 100644 --- a/include/tvm/relax/nested_msg.h +++ b/include/tvm/relax/nested_msg.h @@ -182,7 +182,7 @@ class NestedMsg { * \note This function checks if the msg is leaf. */ T LeafValue() const { - ICHECK(IsLeaf()); + TVM_FFI_ICHECK(IsLeaf()); return ffi::details::AnyUnsafe::CopyFromAnyViewAfterCheck(data_); } @@ -362,7 +362,7 @@ TargetType NestedMsgTo(NestedMsg msg, FMapLeaf fmapleaf, FCombine fcombine) { } else if (msg.IsLeaf()) { return fmapleaf(msg.LeafValue()); } else { - ICHECK(msg.IsNested()); + TVM_FFI_ICHECK(msg.IsNested()); ffi::Array> arr = msg.NestedArray(); ffi::Array subexpr; subexpr.reserve(arr.size()); @@ -401,7 +401,7 @@ Expr NestedMsgToExpr(NestedMsg msg, FType fmapleaf) { simplified_flag &= (simplified_tuple == node->tuple); } else { simplified_tuple = node->tuple; - ICHECK(simplified_tuple.defined()); + TVM_FFI_ICHECK(simplified_tuple.defined()); } } } @@ -432,14 +432,14 @@ NestedMsg CombineNestedMsg(NestedMsg lhs, NestedMsg rhs, FType fcombine if (rhs.IsNull()) return lhs; if (lhs.IsLeaf()) { - ICHECK(rhs.IsLeaf()) << "Cannot combine leaf with nested"; + TVM_FFI_ICHECK(rhs.IsLeaf()) << "Cannot combine leaf with nested"; return NestedMsg(fcombine(lhs.LeafValue(), rhs.LeafValue())); } else { - ICHECK(lhs.IsNested()); - ICHECK(rhs.IsNested()) << "Cannot combine leaf with nested"; + TVM_FFI_ICHECK(lhs.IsNested()); + TVM_FFI_ICHECK(rhs.IsNested()) << "Cannot combine leaf with nested"; ffi::Array> arr_lhs = lhs.NestedArray(); ffi::Array> arr_rhs = rhs.NestedArray(); - ICHECK_EQ(arr_lhs.size(), arr_rhs.size()) + TVM_FFI_ICHECK_EQ(arr_lhs.size(), arr_rhs.size()) << "Cannot combine two nested array with different sizes"; ffi::Array> res; res.reserve(arr_lhs.size()); @@ -465,7 +465,7 @@ NestedMsg MapNestedMsg(NestedMsg msg, FType fmapleaf) { } else if (msg.IsLeaf()) { return fmapleaf(msg.LeafValue()); } else { - ICHECK(msg.IsNested()); + TVM_FFI_ICHECK(msg.IsNested()); ffi::Array> arr = msg.NestedArray(); ffi::Array> res; res.reserve(arr.size()); @@ -492,9 +492,10 @@ NestedMsg MapNestedMsg(NestedMsg msg, FType fmapleaf) { template void DecomposeNestedMsg(Expr expr, NestedMsg msg, FType fvisitleaf) { if (auto* tuple = expr.as()) { - ICHECK(msg.IsNested()) << "Expected nested to match tuple"; + TVM_FFI_ICHECK(msg.IsNested()) << "Expected nested to match tuple"; ffi::Array> arr = msg.NestedArray(); - ICHECK_EQ(arr.size(), tuple->fields.size()) << "Expected nested array size to match tuple size"; + TVM_FFI_ICHECK_EQ(arr.size(), tuple->fields.size()) + << "Expected nested array size to match tuple size"; for (size_t i = 0; i < arr.size(); ++i) { DecomposeNestedMsg(tuple->fields[i], arr[i], fvisitleaf); } @@ -523,7 +524,7 @@ Expr TransformTupleLeaf(Expr expr, std::array, N> msgs, FType ftran if (const auto* tuple = sinfo.as()) { std::array>, N> msg_arrays; for (size_t i = 0; i < N; ++i) { - ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple"; + TVM_FFI_ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple"; msg_arrays[i] = msgs[i].NestedArray(); } bool same = true; @@ -546,7 +547,7 @@ Expr TransformTupleLeaf(Expr expr, std::array, N> msgs, FType ftran return same ? expr : Tuple(fields); } else { for (const auto& msg : msgs) { - ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple"; + TVM_FFI_ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple"; } return ftransleaf(expr, msgs); } @@ -572,7 +573,7 @@ StructInfo TransformTupleLeaf(StructInfo sinfo, std::array, N> msgs if (const auto* tuple = sinfo.as()) { std::array>, N> msg_arrays; for (size_t i = 0; i < N; ++i) { - ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple"; + TVM_FFI_ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple"; msg_arrays[i] = msgs[i].NestedArray(); } bool same = true; @@ -590,7 +591,7 @@ StructInfo TransformTupleLeaf(StructInfo sinfo, std::array, N> msgs return same ? sinfo : TupleStructInfo(fields); } else { for (const auto& msg : msgs) { - ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple"; + TVM_FFI_ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple"; } return ftransleaf(sinfo, msgs); } diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index f08d737fdca5..12b97e20c21d 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -381,7 +381,7 @@ inline ffi::Optional MatchStructInfo(const Expr& expr) { */ template inline const T* GetStructInfoAs(const Expr& expr) { - ICHECK(expr->struct_info_.defined()) + TVM_FFI_ICHECK(expr->struct_info_.defined()) << "The struct_info is not populated, check if you have normalized the expr"; return expr->struct_info_.as(); } @@ -394,7 +394,7 @@ inline const T* GetStructInfoAs(const Expr& expr) { */ inline StructInfo GetStructInfo(const Expr& expr) { auto* ptr = expr->struct_info_.as(); - ICHECK(ptr) << "The struct_info is not populated, check if you have normalized the expr"; + TVM_FFI_ICHECK(ptr) << "The struct_info is not populated, check if you have normalized the expr"; return ffi::GetRef(ptr); } diff --git a/include/tvm/relax/struct_info_functor.h b/include/tvm/relax/struct_info_functor.h index 2ce562754791..e8ba7a80299e 100644 --- a/include/tvm/relax/struct_info_functor.h +++ b/include/tvm/relax/struct_info_functor.h @@ -72,7 +72,7 @@ class StructInfoFunctor { * \return The result of the call */ virtual R VisitStructInfo(const StructInfo& n, Args... args) { - ICHECK(n.defined()); + TVM_FFI_ICHECK(n.defined()); static FStructInfo vtable = InitVTable(); return vtable(n, this, std::forward(args)...); } @@ -92,7 +92,7 @@ class StructInfoFunctor { virtual R VisitStructInfo_(const FuncStructInfoNode* op, Args... args) STRUCT_INFO_FUNCTOR_DEFAULT; virtual R VisitStructInfoDefault_(const Object* op, Args...) { - LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Do not have a default for " << op->GetTypeKey(); throw; // unreachable, written to stop compiler warning } diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 0c698334ac6d..67fe50350d2f 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -92,22 +92,22 @@ class DataType { data_.code = static_cast(code); data_.bits = static_cast(bits); if (is_scalable) { - ICHECK(lanes > 1) << "Invalid value for vscale factor" << lanes; + TVM_FFI_ICHECK(lanes > 1) << "Invalid value for vscale factor" << lanes; } data_.lanes = is_scalable ? static_cast(-lanes) : static_cast(lanes); if (code == kBFloat) { - ICHECK_EQ(bits, 16); + TVM_FFI_ICHECK_EQ(bits, 16); } if (code == kFloat8_e3m4 || code == kFloat8_e4m3 || code == kFloat8_e4m3b11fnuz || code == kFloat8_e4m3fn || code == kFloat8_e4m3fnuz || code == kFloat8_e5m2 || code == kFloat8_e5m2fnuz || code == kFloat8_e8m0fnu) { - ICHECK_EQ(bits, 8); + TVM_FFI_ICHECK_EQ(bits, 8); } if (code == kFloat6_e2m3fn || code == kFloat6_e3m2fn) { - ICHECK_EQ(bits, 6); + TVM_FFI_ICHECK_EQ(bits, 6); } if (code == kFloat4_e2m1fn) { - ICHECK_EQ(bits, 4); + TVM_FFI_ICHECK_EQ(bits, 4); } } /*! \return The type code. */ @@ -120,7 +120,8 @@ class DataType { int lanes() const { int lanes_as_int = static_cast(data_.lanes); if (lanes_as_int < 0) { - LOG(FATAL) << "Can't fetch the lanes of a scalable vector at a compile time."; + TVM_FFI_THROW(InternalError) + << "Can't fetch the lanes of a scalable vector at a compile time."; } return lanes_as_int; } @@ -128,7 +129,7 @@ class DataType { int vscale_factor() const { int lanes_as_int = static_cast(data_.lanes); if (lanes_as_int >= -1) { - LOG(FATAL) << "A fixed length vector doesn't have a vscale factor."; + TVM_FFI_THROW(InternalError) << "A fixed length vector doesn't have a vscale factor."; } return -lanes_as_int; } @@ -427,7 +428,7 @@ inline int GetVectorBytes(DataType dtype) { dtype == DataType::Float6E2M3FN() || dtype == DataType::Float6E3M2FN()) { return 1; } - ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes"; + TVM_FFI_ICHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes"; return data_bits / 8; } diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index f14b22c57628..7139d41cbbb4 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -345,7 +345,7 @@ inline const char* DLDeviceType2Str(int type) { case kDLHexagon: return "hexagon"; default: - LOG(FATAL) << "unknown type = " << type; + TVM_FFI_THROW(InternalError) << "unknown type = " << type; } throw; } @@ -364,7 +364,7 @@ inline bool IsRPCSessionDevice(Device dev) { return (dev.device_type / kRPCSessM * \return the table index. */ inline int GetRPCSessionIndex(Device dev) { - ICHECK(IsRPCSessionDevice(dev)) << "GetRPCSessionIndex: dev has no RPC session"; + TVM_FFI_ICHECK(IsRPCSessionDevice(dev)) << "GetRPCSessionIndex: dev has no RPC session"; return dev.device_type / kRPCSessMask - 1; } @@ -397,8 +397,8 @@ inline std::ostream& operator<<(std::ostream& os, DLDevice dev) { // NOLINT(*) * \return A Device with RPC session mask added, valid on the RPC client. */ inline Device AddRPCSessionMask(Device dev, int session_table_index) { - CHECK(!IsRPCSessionDevice(dev)) << "AddRPCSessionMask: dev already non-zero RPCSessionIndex: " - << dev; + TVM_FFI_ICHECK(!IsRPCSessionDevice(dev)) + << "AddRPCSessionMask: dev already non-zero RPCSessionIndex: " << dev; dev.device_type = static_cast(dev.device_type | (kRPCSessMask * (session_table_index + 1))); return dev; diff --git a/include/tvm/runtime/disco/builtin.h b/include/tvm/runtime/disco/builtin.h index ae119e52652b..3a1e4850a427 100644 --- a/include/tvm/runtime/disco/builtin.h +++ b/include/tvm/runtime/disco/builtin.h @@ -53,7 +53,7 @@ inline std::string ReduceKind2String(ReduceKind kind) { case ReduceKind::kAvg: return "kAvg"; } - LOG(FATAL) << "ValueError: Unknown ReduceKind: " << static_cast(kind); + TVM_FFI_THROW(ValueError) << "Unknown ReduceKind: " << static_cast(kind); } /*! diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h index 756312777c39..758feaf1f4bc 100644 --- a/include/tvm/runtime/disco/session.h +++ b/include/tvm/runtime/disco/session.h @@ -123,7 +123,7 @@ inline std::string DiscoAction2String(DiscoAction action) { case DiscoAction::kDebugSetRegister: return "kDebugSetRegister"; } - LOG(FATAL) << "ValueError: Unknown DiscoAction: " << static_cast(action); + TVM_FFI_THROW(ValueError) << "Unknown DiscoAction: " << static_cast(action); } class SessionObj; diff --git a/include/tvm/runtime/logging.h b/include/tvm/runtime/logging.h index b34b885c042f..d4e2f6f180cf 100644 --- a/include/tvm/runtime/logging.h +++ b/include/tvm/runtime/logging.h @@ -19,7 +19,20 @@ /*! * \file tvm/runtime/logging.h - * \brief logging utilitiesß + * \brief logging utilities + * + * We use the following facilities from tvm/ffi/error.h for + * error handling and checking: + * + * - TVM_FFI_THROW(ErrorKind) << "msg"; + * - TVM_FFI_CHECK(cond, ErrorKind) << "msg"; + * - TVM_FFI_CHECK_EQ(x, y, ErrorKind) << "msg"; + * - TVM_FFI_ICHECK(x) << "msg"; // InternalError + * - TVM_FFI_ICHECK_EQ(x, y) << "msg"; + * - TVM_FFI_DCHECK(x) << "msg"; // Debug-only InternalError + * + * LOG(INFO), LOG(WARNING), LOG(ERROR) are kept for logging. + * LOG(FATAL) is kept for completeness, it throws InternalError. */ #ifndef TVM_RUNTIME_LOGGING_H_ #define TVM_RUNTIME_LOGGING_H_ @@ -30,7 +43,6 @@ #include #include #include -#include #include #include #include @@ -91,92 +103,6 @@ #define TVM_LOG_CUSTOMIZE 0 #endif -// a technique that enables overriding macro names on the number of parameters. This is used -// to define other macros below -#define GET_MACRO(_1, _2, _3, _4, _5, NAME, ...) NAME - -/*! - * \brief COND_X calls COND_X_N where N is the number of parameters passed to COND_X - * X can be any of CHECK_GE, CHECK_EQ, CHECK, or LOG COND_X (but not COND_X_N) - * are supposed to be used outside this file. - * The first parameter of COND_X (and therefore, COND_X_N), which we call 'quit_on_assert', - * is a boolean. The rest of the parameters of COND_X is the same as the parameters of X. - * quit_on_assert determines the overall behavior of COND_X. If it's true COND_X - * quits the program on assertion failure. If it's false, then it moves on and somehow reports - * the assertion failure back to the macro caller in an appropriate manner (e.g, 'return false' - * in a function, or 'continue' or 'break' in a loop) - * The default behavior when quit_on_assertion is false, is to 'return false'. If this is not - * desirable, the macro caller can pass one more last parameter to COND_X to tell COND_X what - * to do when quit_on_assertion is false and the assertion fails. - * - * Rationale: These macros were designed to implement functions that have two behaviors - * in a concise way. Those behaviors are quitting on assertion failures, or trying to - * move on from assertion failures. Note that these macros hide lots of control flow in them, - * and therefore, makes the logic of the whole code slightly harder to understand. However, - * in pieces of code that use these macros frequently, it will significantly shorten the - * amount of code needed to be read, and we won't need to clutter the main logic of the - * function by repetitive control flow structure. The first problem - * mentioned will be improved over time as the developer gets used to the macro. - * - * Here is an example of how to use it - * \code - * bool f(..., bool quit_on_assertion) { - * int a = 0, b = 0; - * ... - * a = ... - * b = ... - * // if quit_on_assertion is true, if a==b, continue, otherwise quit. - * // if quit_on_assertion is false, if a==b, continue, otherwise 'return false' - * // (default behaviour) - * COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quiting" - * ... - * for (int i = 0; i < N; i++) { - * a = ... - * b = ... - * // if quit_on_assertion is true, if a==b, continue, otherwise quit. - * // if quit_on_assertion is false, if a==b, continue, otherwise 'break' - * // (non-default behaviour, therefore, has to be explicitly specified) - * COND_CHECK_EQ(quit_on_assertion, a, b, break) << "some error message when quiting" - * } - * } - * \endcode - */ -#define COND_CHECK_GE(...) \ - GET_MACRO(__VA_ARGS__, COND_CHECK_GE_5, COND_CHECK_GE_4, COND_CHECK_GE_3)(__VA_ARGS__) -#define COND_CHECK_EQ(...) \ - GET_MACRO(__VA_ARGS__, COND_CHECK_EQ_5, COND_CHECK_EQ_4, COND_CHECK_EQ_3)(__VA_ARGS__) -#define COND_CHECK(...) \ - GET_MACRO(__VA_ARGS__, COND_CHECK_5, COND_CHECK_4, COND_CHECK_3, COND_CHECK_2)(__VA_ARGS__) -#define COND_LOG(...) \ - GET_MACRO(__VA_ARGS__, COND_LOG_5, COND_LOG_4, COND_LOG_3, COND_LOG_2)(__VA_ARGS__) - -// Not supposed to be used by users directly. -#define COND_CHECK_OP(quit_on_assert, x, y, what, op) \ - if (!quit_on_assert) { \ - if (!((x)op(y))) what; \ - } else /* NOLINT(*) */ \ - CHECK_##op(x, y) - -#define COND_CHECK_EQ_4(quit_on_assert, x, y, what) COND_CHECK_OP(quit_on_assert, x, y, what, ==) -#define COND_CHECK_GE_4(quit_on_assert, x, y, what) COND_CHECK_OP(quit_on_assert, x, y, what, >=) - -#define COND_CHECK_3(quit_on_assert, x, what) \ - if (!quit_on_assert) { \ - if (!(x)) what; \ - } else /* NOLINT(*) */ \ - CHECK(x) - -#define COND_LOG_3(quit_on_assert, x, what) \ - if (!quit_on_assert) { \ - what; \ - } else /* NOLINT(*) */ \ - LOG(x) - -#define COND_CHECK_EQ_3(quit_on_assert, x, y) COND_CHECK_EQ_4(quit_on_assert, x, y, return false) -#define COND_CHECK_GE_3(quit_on_assert, x, y) COND_CHECK_GE_4(quit_on_assert, x, y, return false) -#define COND_CHECK_2(quit_on_assert, x) COND_CHECK_3(quit_on_assert, x, return false) -#define COND_LOG_2(quit_on_assert, x) COND_LOG_3(quit_on_assert, x, return false) - namespace tvm { namespace runtime { @@ -184,54 +110,16 @@ using ffi::EnvErrorAlreadySet; using ffi::Error; /*! - * \brief Error type for errors from CHECK, ICHECK, and LOG(FATAL). This error + * \brief Error type for errors from LOG(FATAL). This error * contains a backtrace of where it occurred. + * + * \note LOG(FATAL) always throws InternalError. For typed errors, + * use TVM_FFI_THROW(ErrorKind) instead. */ class InternalError : public Error { public: - /*! \brief Construct an error. Not recommended to use directly. Instead use LOG(FATAL). - * - * \param file The file where the error occurred. - * \param lineno The line number where the error occurred. - * \param message The error message to display. - * \param time The time at which the error occurred. This should be in local time. - * \param backtrace Backtrace from when the error occurred. - */ InternalError(std::string file, int lineno, std::string message) - : Error(DetectKind(message), DetectMessage(message), - TVMFFIBacktrace(file.c_str(), lineno, "", 0)) {} - - private: - // try to detect the kind of error from the message when the error type - // is folded into the text message - static std::string DetectKind(const std::string& message) { - size_t pos = message.find("Error:"); - if (pos != std::string::npos) { - size_t end = pos + 6; - size_t begin = pos; - for (; begin != 0 && message[begin - 1] != ' '; --begin) { - } - return message.substr(begin, end - begin - 1); - } else { - return "InternalError"; - } - } - - static std::string DetectMessage(const std::string& message) { - size_t pos = message.find("Error:"); - if (pos != std::string::npos) { - size_t end = pos + 6; - size_t begin = pos; - for (; begin != 0 && message[begin - 1] != ' '; --begin) { - } - if (end < message.size() && message[end] == ' ') { - end += 1; - } - return message.substr(0, begin) + message.substr(end); - } else { - return message; - } - } + : Error("InternalError", std::move(message), TVMFFIBacktrace(file.c_str(), lineno, "", 0)) {} }; /*! \brief Internal implementation */ @@ -497,48 +385,6 @@ class VLogContextEntry { std::stringstream sstream_; }; -template -std::unique_ptr LogCheckFormat(const X& x, const Y& y) { - std::ostringstream os; - os << " (" << x << " vs. " << y << ") "; // CHECK_XX(x, y) requires x and y can be serialized to - // string. Use CHECK(x OP y) otherwise. - return std::make_unique(os.str()); -} - -// Inline _Pragma in macros does not work reliably on old version of MSVC and -// GCC. We wrap all comparisons in a function so that we can use #pragma to -// silence bad comparison warnings. -#define TVM_CHECK_FUNC(name, op) \ - template \ - TVM_ALWAYS_INLINE std::unique_ptr LogCheck##name(const X& x, const Y& y) { \ - if (x op y) return nullptr; \ - return LogCheckFormat(x, y); \ - } \ - TVM_ALWAYS_INLINE std::unique_ptr LogCheck##name(int x, int y) { \ - return LogCheck##name(x, y); \ - } - -#if defined(__GNUC__) || defined(__clang__) // GCC and Clang -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wsign-compare" -#elif defined(_MSC_VER) // MSVC -#pragma warning(push) -#pragma warning(disable : 4389) // '==' : signed/unsigned mismatch -#endif - -TVM_CHECK_FUNC(_LT, <) -TVM_CHECK_FUNC(_GT, >) -TVM_CHECK_FUNC(_LE, <=) -TVM_CHECK_FUNC(_GE, >=) -TVM_CHECK_FUNC(_EQ, ==) -TVM_CHECK_FUNC(_NE, !=) - -#if defined(__GNUC__) || defined(__clang__) // GCC and Clang -#pragma GCC diagnostic pop -#elif defined(_MSC_VER) // MSVC -#pragma warning(pop) -#endif - } // namespace detail #define TVM_LOG_LEVEL_DEBUG 0 @@ -556,27 +402,6 @@ TVM_CHECK_FUNC(_NE, !=) #define LOG_WARNING \ ::tvm::runtime::detail::LogMessage(__FILE__, __LINE__, TVM_LOG_LEVEL_WARNING).stream() -#define TVM_CHECK_BINARY_OP(name, op, x, y) \ - if (auto __tvm__log__err = ::tvm::runtime::detail::LogCheck##name(x, y)) \ - ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \ - << "Check failed: " << #x " " #op " " #y << *__tvm__log__err << ": " - -#define CHECK(x) \ - if (!(x)) \ - ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \ - << "Check failed: (" #x << ") is false: " - -#define CHECK_LT(x, y) TVM_CHECK_BINARY_OP(_LT, <, x, y) -#define CHECK_GT(x, y) TVM_CHECK_BINARY_OP(_GT, >, x, y) -#define CHECK_LE(x, y) TVM_CHECK_BINARY_OP(_LE, <=, x, y) -#define CHECK_GE(x, y) TVM_CHECK_BINARY_OP(_GE, >=, x, y) -#define CHECK_EQ(x, y) TVM_CHECK_BINARY_OP(_EQ, ==, x, y) -#define CHECK_NE(x, y) TVM_CHECK_BINARY_OP(_NE, !=, x, y) -#define CHECK_NOTNULL(x) \ - ((x) == nullptr ? ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \ - << "Check not null: " #x << ' ', \ - (x) : (x)) // NOLINT(*) - #define LOG_IF(severity, condition) \ !(condition) ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(severity) @@ -626,54 +451,6 @@ TVM_CHECK_FUNC(_NE, !=) DLOG_IF(INFO, ::tvm::runtime::detail::VerboseLoggingEnabled(__FILE__, (level))) \ << ::tvm::runtime::detail::ThreadLocalVLogContext()->str() -#if TVM_LOG_DEBUG -#define DCHECK(x) CHECK(x) -#define DCHECK_LT(x, y) CHECK((x) < (y)) -#define DCHECK_GT(x, y) CHECK((x) > (y)) -#define DCHECK_LE(x, y) CHECK((x) <= (y)) -#define DCHECK_GE(x, y) CHECK((x) >= (y)) -#define DCHECK_EQ(x, y) CHECK((x) == (y)) -#define DCHECK_NE(x, y) CHECK((x) != (y)) -#else -#define DCHECK(x) \ - while (false) CHECK(x) -#define DCHECK_LT(x, y) \ - while (false) CHECK((x) < (y)) -#define DCHECK_GT(x, y) \ - while (false) CHECK((x) > (y)) -#define DCHECK_LE(x, y) \ - while (false) CHECK((x) <= (y)) -#define DCHECK_GE(x, y) \ - while (false) CHECK((x) >= (y)) -#define DCHECK_EQ(x, y) \ - while (false) CHECK((x) == (y)) -#define DCHECK_NE(x, y) \ - while (false) CHECK((x) != (y)) -#endif - -#define TVM_ICHECK_INDENT " " - -#define ICHECK_BINARY_OP(name, op, x, y) \ - if (auto __tvm__log__err = ::tvm::runtime::detail::LogCheck##name(x, y)) \ - ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \ - << "InternalError: Check failed: " << #x " " #op " " #y << *__tvm__log__err << ": " - -#define ICHECK(x) \ - if (!(x)) \ - ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \ - << "InternalError: Check failed: (" #x << ") is false: " - -#define ICHECK_LT(x, y) ICHECK_BINARY_OP(_LT, <, x, y) -#define ICHECK_GT(x, y) ICHECK_BINARY_OP(_GT, >, x, y) -#define ICHECK_LE(x, y) ICHECK_BINARY_OP(_LE, <=, x, y) -#define ICHECK_GE(x, y) ICHECK_BINARY_OP(_GE, >=, x, y) -#define ICHECK_EQ(x, y) ICHECK_BINARY_OP(_EQ, ==, x, y) -#define ICHECK_NE(x, y) ICHECK_BINARY_OP(_NE, !=, x, y) -#define ICHECK_NOTNULL(x) \ - ((x) == nullptr ? ::tvm::runtime::detail::LogFatal(__FILE__, __LINE__).stream() \ - << "InternalError: Check not null: " #x << ' ', \ - (x) : (x)) // NOLINT(*) - } // namespace runtime // Re-export error types using runtime::Error; diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index d60b5712c78d..80279a4862e0 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -112,7 +112,7 @@ static_assert(static_cast(TypeIndex::kCustomStaticIndex) >= "Object types that are declared as final, " \ "using the TVM_FFI_DECLARE_OBJECT_INFO_FINAL macro."); \ ObjectName* CopyOnWrite() { \ - ICHECK(data_ != nullptr); \ + TVM_FFI_ICHECK(data_ != nullptr); \ if (!data_.unique()) { \ auto n = ::tvm::ffi::make_object(*(operator->())); \ ObjectPtr(std::move(n)).swap(data_); \ diff --git a/include/tvm/runtime/tensor.h b/include/tvm/runtime/tensor.h index d3d3f367cc74..87fbaf22a9ee 100644 --- a/include/tvm/runtime/tensor.h +++ b/include/tvm/runtime/tensor.h @@ -199,24 +199,24 @@ class Tensor : public tvm::ffi::Tensor { inline bool SaveDLTensor(support::Stream* strm, const DLTensor* tensor); inline void Tensor::CopyFrom(const DLTensor* other) { - ICHECK(data_ != nullptr); + TVM_FFI_ICHECK(data_ != nullptr); CopyFromTo(other, get_mutable()); } inline void Tensor::CopyFrom(const Tensor& other) { - ICHECK(data_ != nullptr); - ICHECK(other.data_ != nullptr); + TVM_FFI_ICHECK(data_ != nullptr); + TVM_FFI_ICHECK(other.data_ != nullptr); CopyFromTo(other.get_mutable(), get_mutable()); } inline void Tensor::CopyTo(DLTensor* other) const { - ICHECK(data_ != nullptr); + TVM_FFI_ICHECK(data_ != nullptr); CopyFromTo(get_mutable(), other); } inline void Tensor::CopyTo(const Tensor& other) const { - ICHECK(data_ != nullptr); - ICHECK(other.data_ != nullptr); + TVM_FFI_ICHECK(data_ != nullptr); + TVM_FFI_ICHECK(other.data_ != nullptr); CopyFromTo(get_mutable(), other.get_mutable()); } @@ -271,19 +271,20 @@ inline void Tensor::Save(support::Stream* strm) const { SaveDLTensor(strm, opera inline bool Tensor::Load(support::Stream* strm) { uint64_t header, reserved; - ICHECK(strm->Read(&header)) << "Invalid DLTensor file format"; - ICHECK(strm->Read(&reserved)) << "Invalid DLTensor file format"; - ICHECK(header == kTVMTensorMagic) << "Invalid DLTensor file format"; + TVM_FFI_ICHECK(strm->Read(&header)) << "Invalid DLTensor file format"; + TVM_FFI_ICHECK(strm->Read(&reserved)) << "Invalid DLTensor file format"; + TVM_FFI_ICHECK(header == kTVMTensorMagic) << "Invalid DLTensor file format"; Device dev; int ndim; DLDataType dtype; - ICHECK(strm->Read(&dev)) << "Invalid DLTensor file format"; - ICHECK(strm->Read(&ndim)) << "Invalid DLTensor file format"; - ICHECK(strm->Read(&dtype)) << "Invalid DLTensor file format"; - ICHECK_EQ(dev.device_type, kDLCPU) << "Invalid DLTensor device: can only save as CPU tensor"; + TVM_FFI_ICHECK(strm->Read(&dev)) << "Invalid DLTensor file format"; + TVM_FFI_ICHECK(strm->Read(&ndim)) << "Invalid DLTensor file format"; + TVM_FFI_ICHECK(strm->Read(&dtype)) << "Invalid DLTensor file format"; + TVM_FFI_ICHECK_EQ(dev.device_type, kDLCPU) + << "Invalid DLTensor device: can only save as CPU tensor"; std::vector shape(ndim); if (ndim != 0) { - ICHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DLTensor file format"; + TVM_FFI_ICHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DLTensor file format"; } Tensor ret = Tensor::Empty(ffi::Shape(shape), dtype, dev); int64_t num_elems = 1; @@ -292,12 +293,12 @@ inline bool Tensor::Load(support::Stream* strm) { num_elems *= ret->shape[i]; } int64_t data_byte_size; - ICHECK(strm->Read(&data_byte_size)) << "Invalid DLTensor file format"; - ICHECK(data_byte_size == num_elems * elem_bytes) << "Invalid DLTensor file format"; + TVM_FFI_ICHECK(strm->Read(&data_byte_size)) << "Invalid DLTensor file format"; + TVM_FFI_ICHECK(data_byte_size == num_elems * elem_bytes) << "Invalid DLTensor file format"; auto read_ret = strm->Read(ret->data, data_byte_size); // Only check non-empty data if (ndim > 0 && shape[0] != 0) { - ICHECK(read_ret) << "Invalid DLTensor file format"; + TVM_FFI_ICHECK(read_ret) << "Invalid DLTensor file format"; } if (!TVM_FFI_IO_NO_ENDIAN_SWAP) { ffi::ByteSwap(ret->data, elem_bytes, num_elems); diff --git a/include/tvm/runtime/vm/bytecode.h b/include/tvm/runtime/vm/bytecode.h index bc84578fd5d5..5a60febf8443 100644 --- a/include/tvm/runtime/vm/bytecode.h +++ b/include/tvm/runtime/vm/bytecode.h @@ -106,8 +106,9 @@ struct Instruction { os << "kFuncIdx"; break; default: - LOG(FATAL) << "Internal error: " - << "Invalid ArgKind with integer value " << static_cast(kind); + TVM_FFI_THROW(InternalError) + << "Internal error: " + << "Invalid ArgKind with integer value " << static_cast(kind); } return os; } @@ -173,8 +174,8 @@ struct Instruction { explicit Arg(ExecWord data) : data_(data) {} /*! \brief Construct from the kind and value. */ Arg(ArgKind kind, Index value) { - ICHECK_LE(value, kValueMaxLimit); - ICHECK_GE(value, kValueMinLimit); + TVM_FFI_ICHECK_LE(value, kValueMaxLimit); + TVM_FFI_ICHECK_GE(value, kValueMinLimit); data_ = (static_cast(kind) << kValueBit) | (value & kValueMask); } /*! \brief The underlying stored data. */ diff --git a/include/tvm/s_tir/data_layout.h b/include/tvm/s_tir/data_layout.h index 5bdad33ba099..8d1ad0ca4c09 100644 --- a/include/tvm/s_tir/data_layout.h +++ b/include/tvm/s_tir/data_layout.h @@ -258,9 +258,9 @@ class Layout : public ObjectRef { } const LayoutAxis& operator[](int32_t i) const { - ICHECK(defined()) << "Try to access axis from an undefined layout."; + TVM_FFI_ICHECK(defined()) << "Try to access axis from an undefined layout."; int32_t index = i < 0 ? static_cast(ndim() + i) : i; - ICHECK(index >= 0 && static_cast(index) < ndim()) << "Invalid index " << i; + TVM_FFI_ICHECK(index >= 0 && static_cast(index) < ndim()) << "Invalid index " << i; const tir::IterVar axis = operator->()->axes[index]; return LayoutAxis::Get(axis); } diff --git a/include/tvm/s_tir/meta_schedule/builder.h b/include/tvm/s_tir/meta_schedule/builder.h index c70c0255b826..719cf7f7ed06 100644 --- a/include/tvm/s_tir/meta_schedule/builder.h +++ b/include/tvm/s_tir/meta_schedule/builder.h @@ -164,7 +164,7 @@ class PyBuilderNode : public BuilderNode { } ffi::Array Build(const ffi::Array& build_inputs) final { - ICHECK(f_build != nullptr) << "PyBuilder's Build method not implemented!"; + TVM_FFI_ICHECK(f_build != nullptr) << "PyBuilder's Build method not implemented!"; return f_build(build_inputs); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.PyBuilder", PyBuilderNode, BuilderNode); diff --git a/include/tvm/s_tir/meta_schedule/database.h b/include/tvm/s_tir/meta_schedule/database.h index 039ca4ff055d..c6947e573473 100644 --- a/include/tvm/s_tir/meta_schedule/database.h +++ b/include/tvm/s_tir/meta_schedule/database.h @@ -270,7 +270,7 @@ class DatabaseNode : public runtime::Object { void DumpPruned(Database destination); /*! \brief Return a reference to the owned module equality method instance. */ const ModuleEquality& GetModuleEquality() const { - ICHECK(mod_eq_); + TVM_FFI_ICHECK(mod_eq_); return *mod_eq_; } @@ -398,28 +398,29 @@ class PyDatabaseNode : public DatabaseNode { } bool HasWorkload(const IRModule& mod) final { - ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!"; + TVM_FFI_ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!"; return f_has_workload(mod); } Workload CommitWorkload(const IRModule& mod) final { - ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!"; + TVM_FFI_ICHECK(f_commit_workload != nullptr) + << "PyDatabase's CommitWorkload method not implemented!"; return f_commit_workload(mod); } void CommitTuningRecord(const TuningRecord& record) final { - ICHECK(f_commit_tuning_record != nullptr) + TVM_FFI_ICHECK(f_commit_tuning_record != nullptr) << "PyDatabase's CommitTuningRecord method not implemented!"; f_commit_tuning_record(record); } ffi::Array GetTopK(const Workload& workload, int top_k) final { - ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!"; + TVM_FFI_ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!"; return f_get_top_k(workload, top_k); } ffi::Array GetAllTuningRecords() final { - ICHECK(f_get_all_tuning_records != nullptr) + TVM_FFI_ICHECK(f_get_all_tuning_records != nullptr) << "PyDatabase's GetAllTuningRecords method not implemented!"; return f_get_all_tuning_records(); } @@ -452,7 +453,7 @@ class PyDatabaseNode : public DatabaseNode { } int64_t Size() final { - ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!"; + TVM_FFI_ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!"; return f_size(); } diff --git a/include/tvm/s_tir/meta_schedule/runner.h b/include/tvm/s_tir/meta_schedule/runner.h index a08d889f8825..9fd161b2dbe5 100644 --- a/include/tvm/s_tir/meta_schedule/runner.h +++ b/include/tvm/s_tir/meta_schedule/runner.h @@ -139,7 +139,7 @@ class RunnerFutureNode : public runtime::Object { * \return A boolean indicating whether the runner has finished. */ bool Done() const { - ICHECK(f_done != nullptr) << "PyRunnerFuture's Done method not implemented!"; + TVM_FFI_ICHECK(f_done != nullptr) << "PyRunnerFuture's Done method not implemented!"; return f_done(); } /*! @@ -147,7 +147,7 @@ class RunnerFutureNode : public runtime::Object { * \return The runner's output. */ RunnerResult Result() const { - ICHECK(f_result != nullptr) << "PyRunnerFuture's Result method not implemented!"; + TVM_FFI_ICHECK(f_result != nullptr) << "PyRunnerFuture's Result method not implemented!"; return f_result(); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.RunnerFuture", RunnerFutureNode, @@ -236,7 +236,7 @@ class PyRunnerNode : public RunnerNode { } ffi::Array Run(ffi::Array runner_inputs) final { - ICHECK(f_run != nullptr) << "PyRunner's Run method not implemented!"; + TVM_FFI_ICHECK(f_run != nullptr) << "PyRunner's Run method not implemented!"; return f_run(runner_inputs); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.meta_schedule.PyRunner", PyRunnerNode, RunnerNode); diff --git a/include/tvm/s_tir/sblock_dependence_info.h b/include/tvm/s_tir/sblock_dependence_info.h index e1ec8b958815..408afb639bb0 100644 --- a/include/tvm/s_tir/sblock_dependence_info.h +++ b/include/tvm/s_tir/sblock_dependence_info.h @@ -74,8 +74,8 @@ class SBlockDependenceInfoNode : public Object { */ SBlockScope GetSBlockScope(const StmtSRef& scope_root) const { auto it = sref2scope.find(scope_root); - CHECK(it != sref2scope.end()) - << "IndexError: Cannot find the corresponding SBlockScope to the block sref:\n" + TVM_FFI_CHECK(it != sref2scope.end(), IndexError) + << "Cannot find the corresponding SBlockScope to the block sref:\n" << ffi::GetRef(scope_root->stmt); return it->second; } diff --git a/include/tvm/s_tir/utils.h b/include/tvm/s_tir/utils.h index bedcb372d3a8..621efba737b8 100644 --- a/include/tvm/s_tir/utils.h +++ b/include/tvm/s_tir/utils.h @@ -36,7 +36,7 @@ namespace tir { */ #define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \ SRef->StmtAs(); \ - ICHECK(Result) + TVM_FFI_CHECK(Result, TypeError) /*! * \brief A helper macro to convert an sref to the block it points to, @@ -46,12 +46,12 @@ namespace tir { * * \param SRef The SRef to be cast */ -#define TVM_SREF_TO_SBLOCK(SRef) \ - [&]() { \ - auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::SBlockNode) \ - << "TypeError: Expects StmtSRef `" << #SRef << "` points to `Block`, but gets: " \ - << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \ - return result; \ +#define TVM_SREF_TO_SBLOCK(SRef) \ + [&]() { \ + auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::SBlockNode) \ + << "Expects StmtSRef `" << #SRef << "` points to `Block`, but gets: " \ + << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \ + return result; \ }() /*! @@ -62,12 +62,12 @@ namespace tir { * * \param SRef The SRef to be cast */ -#define TVM_SREF_TO_FOR(SRef) \ - [&]() { \ - auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::ForNode) \ - << "TypeError: Expects StmtSRef `" << #SRef << "` points to `Loop`, but gets: " \ - << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \ - return result; \ +#define TVM_SREF_TO_FOR(SRef) \ + [&]() { \ + auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::ForNode) \ + << "Expects StmtSRef `" << #SRef << "` points to `Loop`, but gets: " \ + << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \ + return result; \ }() /*! @@ -79,7 +79,7 @@ namespace tir { */ #define TVM_TYPE_AS_OR_ERR(Result, From, Type) \ From.as(); \ - ICHECK(Result) + TVM_FFI_CHECK(Result, TypeError) /*! * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as`, @@ -87,12 +87,12 @@ namespace tir { * \param From The ObjectRef to be downcast * \param Type The type to be downcast to */ -#define TVM_TYPE_AS(From, Type) \ - [&]() { \ - auto result = TVM_TYPE_AS_OR_ERR(result, (From), Type) \ - << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \ - << "`, but gets: " << ((From).GetTypeKey()); \ - return result; \ +#define TVM_TYPE_AS(From, Type) \ + [&]() { \ + auto result = TVM_TYPE_AS_OR_ERR(result, (From), Type) \ + << "Expects `" << #From << "` to have type `" << Type::_type_key \ + << "`, but gets: " << ((From).GetTypeKey()); \ + return result; \ }() /*! @@ -107,14 +107,14 @@ inline void SetSeqIndex(std::unordered_map& stmt2ref, const Stmt& stmt, int seq_index, bool include_loops = true) { if (const auto* realize = stmt.as()) { const SBlockNode* block = realize->block.get(); - ICHECK(stmt2ref.count(block)); + TVM_FFI_ICHECK(stmt2ref.count(block)); stmt2ref.at(block)->seq_index = seq_index; } else if (const auto* block = stmt.as()) { - ICHECK(stmt2ref.count(block)); + TVM_FFI_ICHECK(stmt2ref.count(block)); stmt2ref.at(block)->seq_index = seq_index; } else if (const auto* loop = stmt.as()) { if (!include_loops) return; - ICHECK(stmt2ref.count(loop)); + TVM_FFI_ICHECK(stmt2ref.count(loop)); stmt2ref.at(loop)->seq_index = seq_index; } } diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index 93e4d10317fe..47ed628da0ad 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -116,7 +116,7 @@ class IRBuilderFrame : public runtime::ObjectRef { * \sa IRBuilderFrameNode::EnterWithScope */ inline void EnterWithScope() { - ICHECK(data_ != nullptr); + TVM_FFI_ICHECK(data_ != nullptr); static_cast(data_.get())->EnterWithScope(); } /*! @@ -124,7 +124,7 @@ class IRBuilderFrame : public runtime::ObjectRef { * \sa IRBuilderFrameNode::ExitWithScope */ inline void ExitWithScope() { - ICHECK(data_ != nullptr); + TVM_FFI_ICHECK(data_ != nullptr); static_cast(data_.get())->ExitWithScope(); data_.reset(); } @@ -296,9 +296,10 @@ inline ffi::Optional IRBuilderNode::GetLastFrame() const { template inline TObjectRef IRBuilderNode::Get() const { using TObject = typename TObjectRef::ContainerType; - CHECK(result.defined()) << "IndexError: No result exists in IRBuilder yet"; + TVM_FFI_CHECK(result.defined(), IndexError) << "No result exists in IRBuilder yet"; const auto* n = result.as(); - CHECK(n != nullptr) << "TypeError: IRBuilder result is not of type: " << TObject::_type_key; + TVM_FFI_CHECK(n != nullptr, TypeError) + << "IRBuilder result is not of type: " << TObject::_type_key; return ffi::GetRef(n); } diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index b5d50d89019b..cf8c72daf89a 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -344,7 +344,7 @@ inline TDoc IRDocsifierNode::AsDoc(const Any& value, const AccessPath& path) con AddDocDecoration(d, obj, path, cfg); return Downcast(d); } else { - LOG(FATAL) << "TypeError: Cannot handle Any type: `" << value.GetTypeKey() << "`"; + TVM_FFI_THROW(TypeError) << "Cannot handle Any type: `" << value.GetTypeKey() << "`"; TVM_FFI_UNREACHABLE(); } } diff --git a/include/tvm/script/printer/ir_docsifier_functor.h b/include/tvm/script/printer/ir_docsifier_functor.h index 4500a7d8607b..68caa5ff4d97 100644 --- a/include/tvm/script/printer/ir_docsifier_functor.h +++ b/include/tvm/script/printer/ir_docsifier_functor.h @@ -77,9 +77,10 @@ class IRDocsifierFunctor { LOG(WARNING) << "ObjectFunctor calls un-registered function on type: " << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")" << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj; - ICHECK(false) << "ObjectFunctor calls un-registered function on type: " - << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")" - << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj; + TVM_FFI_ICHECK(false) << "ObjectFunctor calls un-registered function on type: " + << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token + << ")" + << ". ObjectType: " << obj->GetTypeKey() << ". Object: " << obj; } /*! @@ -98,15 +99,15 @@ class IRDocsifierFunctor { } ffi::Function& slot = (*table)[type_index]; if (slot != nullptr) { - ICHECK(false) << "Dispatch for type is already registered: " - << runtime::Object::TypeIndex2Key(type_index); + TVM_FFI_ICHECK(false) << "Dispatch for type is already registered: " + << runtime::Object::TypeIndex2Key(type_index); } slot = f; return *this; } TSelf& set_fallback(ffi::Function f) { - ICHECK(!dispatch_fallback_.has_value()) << "Fallback is already defined"; + TVM_FFI_ICHECK(!dispatch_fallback_.has_value()) << "Fallback is already defined"; dispatch_fallback_ = f; return *this; } diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index 109a98b3d14a..9cd8cf705543 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -92,7 +92,7 @@ class LinearCongruentialEngine { rand_state = 1; } if (rand_state < 0) { - LOG(FATAL) << "ValueError: Random seed must be non-negative"; + TVM_FFI_THROW(ValueError) << "Random seed must be non-negative"; } return rand_state; } @@ -101,7 +101,7 @@ class LinearCongruentialEngine { * \param rand_state The random state given in result_type. */ void Seed(TRandState rand_state) { - ICHECK(rand_state_ptr_ != nullptr); + TVM_FFI_ICHECK(rand_state_ptr_ != nullptr); *rand_state_ptr_ = NormalizeSeed(rand_state); } diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 5dc261d1fd6e..02ac88a9af6e 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -241,7 +241,7 @@ inline TargetKindAttrMap TargetKind::GetAttrMap(const ffi::String& at template inline TargetKindRegEntry& TargetKindRegEntry::set_attr(const ffi::String& attr_name, const ValueType& value, int plevel) { - ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; + TVM_FFI_ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; ffi::Any rv; rv = value; UpdateAttr(attr_name, rv, plevel); diff --git a/include/tvm/target/virtual_device.h b/include/tvm/target/virtual_device.h index ebe5eb39f580..889d0eff8904 100644 --- a/include/tvm/target/virtual_device.h +++ b/include/tvm/target/virtual_device.h @@ -235,8 +235,8 @@ class VirtualDeviceNode : public AttrsNodeReflAdapter { * Physical Devices" above. */ Device ToDevice() const { - ICHECK(device_type_int != kInvalidDeviceType); - ICHECK(virtual_device_id != -1); + TVM_FFI_ICHECK(device_type_int != kInvalidDeviceType); + TVM_FFI_ICHECK(virtual_device_id != -1); Device device; device.device_type = device_type(); device.device_id = virtual_device_id; @@ -288,7 +288,7 @@ class VirtualDevice : public ObjectRef { * The target and memory scope will be unconstrained. */ static VirtualDevice ForDeviceType(DLDeviceType device_type, int virtual_device_id = -1) { - ICHECK_GT(device_type, 0); + TVM_FFI_ICHECK_GT(device_type, 0); return VirtualDevice(device_type, virtual_device_id); } static VirtualDevice ForDeviceType(int device_type, int virtual_device_id = -1) { diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h index c2b402225928..c38ab9a53df4 100644 --- a/include/tvm/tir/expr_functor.h +++ b/include/tvm/tir/expr_functor.h @@ -58,7 +58,7 @@ namespace tir { * }; * MyExprFunctor f; * Var x("x"); - * ICHECK_EQ(f(x + 1, 2), 3); + * TVM_FFI_ICHECK_EQ(f(x + 1, 2), 3); * \endcode * * \note Why do we need this more powerful Functor: @@ -150,7 +150,7 @@ class ExprFunctor { virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Object* op, Args...) { - LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Do not have a default for " << op->GetTypeKey(); TVM_FFI_UNREACHABLE(); } diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 57f868151418..050063300b71 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -962,7 +962,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) // Use IntImm if it is a small integer uint64_t uval = static_cast(value); if (value < static_cast(0)) { - LOG(FATAL) << "cannot make uint from negative value " << value; + TVM_FFI_THROW(InternalError) << "cannot make uint from negative value " << value; } else if (uval <= static_cast(std::numeric_limits::max())) { return IntImm(t, static_cast(value), span); } else { @@ -981,7 +981,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) if (static_cast(t.code()) >= static_cast(DataType::kCustomBegin)) { return FloatImm(t, static_cast(value), span); } - LOG(FATAL) << "cannot make const for type " << t; + TVM_FFI_THROW(InternalError) << "cannot make const for type " << t; throw; } diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h index c87ccd741a5e..e9727f7ab3d8 100644 --- a/include/tvm/tir/op_attr_types.h +++ b/include/tvm/tir/op_attr_types.h @@ -144,7 +144,7 @@ inline std::ostream& operator<<(std::ostream& os, CallEffectKind side_effect) { return os << "kControlJump"; default: - LOG(FATAL) << "Unknown CallEffectKind: " << static_cast(side_effect); + TVM_FFI_THROW(InternalError) << "Unknown CallEffectKind: " << static_cast(side_effect); } } diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 4d0029803bfd..a26ab2dc3293 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1227,7 +1227,7 @@ inline const char* ForKind2String(ForKind t) { case ForKind::kThreadBinding: return "thread_binding"; } - LOG(FATAL) << "Unknown ForKind" << t; + TVM_FFI_THROW(InternalError) << "Unknown ForKind" << t; } } // namespace tir diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 227e45ddc788..ee986c3f92f8 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -96,7 +96,7 @@ class StmtFunctor { virtual R VisitStmt_(const SBlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const SBlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmtDefault_(const Object* op, Args...) { - LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Do not have a default for " << op->GetTypeKey(); TVM_FFI_UNREACHABLE(); } diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h index 2aedef4c58b6..41a6ed6ca5c8 100644 --- a/include/tvm/topi/broadcast.h +++ b/include/tvm/topi/broadcast.h @@ -49,17 +49,17 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, const tvm::ffi::Array& output_shape, std::string name = "T_broadcast_to", std::string tag = kBroadcast) { - ICHECK_GE(output_shape.size(), t->shape.size()) + TVM_FFI_ICHECK_GE(output_shape.size(), t->shape.size()) << "Not a broadcast, output dimensionality smaller than input.\noutput: " << output_shape << "\nvs\ninput: " << t; auto bh = detail::BroadcastShape(output_shape, t->shape); - ICHECK_EQ(output_shape.size(), bh.common_shape.size()); + TVM_FFI_ICHECK_EQ(output_shape.size(), bh.common_shape.size()); ffi::Array oshape; for (size_t i = 0; i < output_shape.size(); ++i) { if (output_shape[i].as() == nullptr) { oshape.push_back(output_shape[i]); } else { - ICHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i])); + TVM_FFI_ICHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i])); oshape.push_back(bh.common_shape[i]); } } diff --git a/include/tvm/topi/detail/broadcast.h b/include/tvm/topi/detail/broadcast.h index aab6fea22d2c..9d9f78ba022e 100644 --- a/include/tvm/topi/detail/broadcast.h +++ b/include/tvm/topi/detail/broadcast.h @@ -43,8 +43,8 @@ struct BroadcastHelper { }; static inline DataType CommonType(DataType type1, DataType type2) { - ICHECK(type1.is_scalar() && type2.is_scalar()); - ICHECK(type1.code() == type2.code()); + TVM_FFI_ICHECK(type1.is_scalar() && type2.is_scalar()); + TVM_FFI_ICHECK(type1.code() == type2.code()); return DataType(type1.code(), std::max(type1.bits(), type2.bits()), /*lanes=*/1); } @@ -72,7 +72,7 @@ inline BroadcastHelper BroadcastShape(const tvm::ffi::Array& shap bh.vars1.push_front(bh.all_vars[0]); bh.vars2.push_front(bh.all_vars[0]); } else if (topi::detail::EqualCheck(one, shape1[s1_size - i])) { - ICHECK(!topi::detail::EqualCheck(one, shape2[s2_size - i])); + TVM_FFI_ICHECK(!topi::detail::EqualCheck(one, shape2[s2_size - i])); bh.common_shape.push_front(cast_if_needed(common_type, shape2[s2_size - i])); bh.vars2.push_front(bh.all_vars[0]); } else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) { @@ -92,10 +92,11 @@ inline BroadcastHelper BroadcastShape(const tvm::ffi::Array& shap bh.vars1.push_front(bh.all_vars[0]); bh.vars2.push_front(bh.all_vars[0]); } else { - ICHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and " - << shape2[s2_size - i] - << " in: " << tvm::ffi::Array(shape1.begin(), shape1.end()) - << " and " << tvm::ffi::Array(shape2.begin(), shape2.end()); + TVM_FFI_ICHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and " + << shape2[s2_size - i] << " in: " + << tvm::ffi::Array(shape1.begin(), shape1.end()) + << " and " + << tvm::ffi::Array(shape2.begin(), shape2.end()); } } // Remaining dimensions whether on shape1 or shape2 can always be completed @@ -114,7 +115,7 @@ inline tvm::ffi::Array InputIndexFromBroadcast( const tvm::ffi::Array& ovars, const tvm::te::Tensor& T, const std::deque& my_vars, const std::deque& all_vars) { tvm::ffi::Array ivars; - ICHECK_EQ(ovars.size(), all_vars.size()); + TVM_FFI_ICHECK_EQ(ovars.size(), all_vars.size()); // N^2, could use a map but NBD. size_t expected_dims = T->shape.size(); for (size_t i = 0; i < ovars.size(); ++i) { @@ -132,7 +133,7 @@ inline tvm::ffi::Array InputIndexFromBroadcast( ivars.push_back(tvm::tir::make_zero(ovars[i].dtype())); } } - ICHECK(expected_dims == ivars.size()); + TVM_FFI_ICHECK(expected_dims == ivars.size()); return ivars; } diff --git a/include/tvm/topi/detail/constant_utils.h b/include/tvm/topi/detail/constant_utils.h index 74b4ce143cad..5bcc64ba125c 100644 --- a/include/tvm/topi/detail/constant_utils.h +++ b/include/tvm/topi/detail/constant_utils.h @@ -92,7 +92,8 @@ inline std::vector GetConstIntValues(ffi::Array exprs, const std: std::vector result; if (!exprs.defined()) return result; for (auto expr : exprs) { - ICHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers"; + TVM_FFI_ICHECK(IsConstInt(expr)) + << "All elements of " << var_name << " must be constant integers"; result.push_back(GetConstInt(expr)); } return result; @@ -112,7 +113,8 @@ inline std::vector GetConstInt64Values(ffi::Array exprs, std::vector result; if (!exprs.defined()) return result; for (auto expr : exprs) { - ICHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers"; + TVM_FFI_ICHECK(IsConstInt(expr)) + << "All elements of " << var_name << " must be constant integers"; result.push_back(GetConstInt(expr)); } return result; diff --git a/include/tvm/topi/detail/extern.h b/include/tvm/topi/detail/extern.h index 05543f74a50b..674cdabdab6d 100644 --- a/include/tvm/topi/detail/extern.h +++ b/include/tvm/topi/detail/extern.h @@ -65,7 +65,7 @@ inline ffi::Array make_extern(const ffi::Array>& ou const ffi::Array& inputs, FExtern fextern, std::string name, std::string tag, ::tvm::ffi::Map attrs) { - ICHECK_EQ(out_shapes.size(), out_types.size()) + TVM_FFI_ICHECK_EQ(out_shapes.size(), out_types.size()) << "make_extern: out_shapes and out_types must have equal size"; ffi::Array input_placeholders; @@ -98,7 +98,7 @@ inline ffi::Array make_extern(const ffi::Array>& ou * \return An expression representing the pack operation */ inline PrimExpr pack_buffer(Buffer buf) { - ICHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element"; + TVM_FFI_ICHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element"; auto shape = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), buf->shape); PrimExpr strides; diff --git a/include/tvm/topi/detail/ravel_unravel.h b/include/tvm/topi/detail/ravel_unravel.h index 27d2f9180251..ffc52ae0d2a0 100644 --- a/include/tvm/topi/detail/ravel_unravel.h +++ b/include/tvm/topi/detail/ravel_unravel.h @@ -43,7 +43,7 @@ using namespace tvm::te; * \return The index after flattening */ inline PrimExpr RavelIndex(ffi::Array indices, ffi::Array shape) { - ICHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size"; + TVM_FFI_ICHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size"; if (indices.size() == 0U) { return 0; } diff --git a/include/tvm/topi/detail/strided_slice.h b/include/tvm/topi/detail/strided_slice.h index e75aeed8b97d..9f59828a8f43 100644 --- a/include/tvm/topi/detail/strided_slice.h +++ b/include/tvm/topi/detail/strided_slice.h @@ -55,7 +55,7 @@ inline std::tuple, std::vector, std::vector stride_vec(strides.size(), 1); if (slice_mode == "end") { for (size_t i = 0; i < strides.size(); ++i) { - ICHECK(strides[i].defined()); + TVM_FFI_ICHECK(strides[i].defined()); stride_vec[i] = GetConstInt(strides[i]); } } @@ -121,7 +121,7 @@ inline ffi::Array StridedSliceOutputShape( const std::vector& end, const std::vector& strides, const ffi::Array& axes, std::string slice_mode, const ffi::Array& begin_canonicalized, bool use_any = false) { - ICHECK(!use_any) << "StridedSliceOutputShape does not legacy use_any"; + TVM_FFI_ICHECK(!use_any) << "StridedSliceOutputShape does not legacy use_any"; const size_t src_tensor_dim = ishape.size(); ffi::Array out_shape; for (size_t i = 0; i < src_tensor_dim; ++i) { @@ -131,13 +131,13 @@ inline ffi::Array StridedSliceOutputShape( for (size_t i = 0; i < axes.size(); ++i) { if (ishape[axes[i].IntValue()]->IsInstance()) { const int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]); - ICHECK(begin_canonicalized[i]->IsInstance()); + TVM_FFI_ICHECK(begin_canonicalized[i]->IsInstance()); int64_t begin_i = GetConstInt(begin_canonicalized[i]); int64_t end_i = CanonicalizeIndex(end[i], dim_i, strides[i]); int interval = std::abs(end_i - begin_i); int slice_size = static_cast((interval + std::abs(strides[i]) - 1) / std::abs(strides[i])); - ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) + TVM_FFI_ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) << ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i; out_shape.Set(axes[i].IntValue(), cast(out_shape[i].dtype(), PrimExpr(slice_size))); } else { diff --git a/include/tvm/topi/elemwise.h b/include/tvm/topi/elemwise.h index 0ed082b0c140..c0eada6f1687 100644 --- a/include/tvm/topi/elemwise.h +++ b/include/tvm/topi/elemwise.h @@ -324,7 +324,7 @@ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "te */ inline Tensor elemwise_sum(const ffi::Array& xs, std::string name = "T_elemwise_sum", std::string tag = kElementWise) { - ICHECK_GT(xs.size(), 0) << "elemwise sum must have at least one input tensor."; + TVM_FFI_ICHECK_GT(xs.size(), 0) << "elemwise sum must have at least one input tensor."; return compute( xs[0]->shape, [&](const ffi::Array& i) { diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 36ce8594b3db..01b82cb3f648 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -100,8 +100,9 @@ inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, double alpha = 0.1, inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& slope, const int axis = 1, std::string name = "T_prelu", std::string tag = kBroadcast) { - ICHECK((size_t)axis < x->shape.size()) << "Wrong axis (" << axis << ")value. "; - ICHECK(topi::detail::GetConstInt(slope->shape[0]) == topi::detail::GetConstInt(x->shape[axis])) + TVM_FFI_ICHECK((size_t)axis < x->shape.size()) << "Wrong axis (" << axis << ")value. "; + TVM_FFI_ICHECK(topi::detail::GetConstInt(slope->shape[0]) == + topi::detail::GetConstInt(x->shape[axis])) << "Wrong slope shape received."; return tvm::te::compute( @@ -164,8 +165,8 @@ inline tvm::te::Tensor pad( } arith::Analyzer analyzer; - ICHECK_GE(pad_before.size(), 1); - ICHECK_EQ(pad_before.size(), pad_after.size()); + TVM_FFI_ICHECK_GE(pad_before.size(), 1); + TVM_FFI_ICHECK_EQ(pad_before.size(), pad_after.size()); tvm::ffi::Array pad_before_int32; tvm::ffi::Array pad_after_int32; @@ -269,8 +270,8 @@ inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I, const tvm::te::Tens int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1, std::string name = "T_conv2d_nchw", std::string tag = kConv2dNCHW) { - ICHECK_EQ(4, I->shape.size()); - ICHECK_EQ(4, W->shape.size()); + TVM_FFI_ICHECK_EQ(4, I->shape.size()); + TVM_FFI_ICHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; tvm::ffi::Array output_shape{ @@ -313,8 +314,8 @@ inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, const tvm::te::Tens int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1, std::string name = "T_conv2d_hwcn", std::string tag = kConv2dHWCN) { - ICHECK_EQ(4, I->shape.size()); - ICHECK_EQ(4, W->shape.size()); + TVM_FFI_ICHECK_EQ(4, I->shape.size()); + TVM_FFI_ICHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; tvm::ffi::Array output_shape{ @@ -358,8 +359,8 @@ inline tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor& I, const tvm int stride_w = 1, std::string name = "T_depthwise_conv2d_nchw", std::string tag = kDepthwiseConv2dNCHW) { - ICHECK_EQ(4, I->shape.size()); - ICHECK_EQ(4, W->shape.size()); + TVM_FFI_ICHECK_EQ(4, I->shape.size()); + TVM_FFI_ICHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; auto pCM = W->shape[1]; // channel_multiplier @@ -387,8 +388,8 @@ inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I, const tvm int stride_w = 1, std::string name = "T_depthwise_conv2d_nhwc", std::string tag = kDepthwiseConv2dNHWC) { - ICHECK_EQ(4, I->shape.size()); - ICHECK_EQ(4, W->shape.size()); + TVM_FFI_ICHECK_EQ(4, I->shape.size()); + TVM_FFI_ICHECK_EQ(4, W->shape.size()); auto pH = I->shape[1]; auto pW = I->shape[2]; auto pCM = W->shape[1]; // channel_multiplier @@ -436,8 +437,8 @@ inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::t int stride_w = 1, std::string name = "T_group_conv2d_ngchw", std::string tag = kGroupConv2d) { - ICHECK_EQ(5, I->shape.size()); - ICHECK_EQ(5, W->shape.size()); + TVM_FFI_ICHECK_EQ(5, I->shape.size()); + TVM_FFI_ICHECK_EQ(5, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; tvm::ffi::Array output_shape{ @@ -487,8 +488,8 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, std::string name = "space_to_batch_nd", std::string tag = kInjective) { tvm::te::Tensor padded_t; - CHECK_EQ(pad_before.size(), pad_after.size()); - CHECK_EQ(block_shape.size(), pad_before.size()) + TVM_FFI_ICHECK_EQ(pad_before.size(), pad_after.size()); + TVM_FFI_ICHECK_EQ(block_shape.size(), pad_before.size()) << "Paddings must be provided for each spatial dimension"; tvm::ffi::Array pad_before_int32; tvm::ffi::Array pad_after_int32; @@ -526,7 +527,7 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, for (size_t i = 1; i <= num_block_dims; i++) { int padded_input = static_cast(GetConstInt(padded_shape[i])); int block_size = static_cast(GetConstInt(block_shape[i - 1])); - CHECK_EQ((padded_input % block_size), 0) + TVM_FFI_ICHECK_EQ((padded_input % block_size), 0) << "(" << i << ")th " "Input dimension after padding (" @@ -628,7 +629,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, int begin_i = static_cast(GetConstInt(crop_begin_list[i - 1])); int end_i = static_cast(GetConstInt(crop_end_list[i - 1])); int out_i = static_cast(GetConstInt(r_p_shape[i])); - CHECK_GT(out_i, (begin_i + end_i)) + TVM_FFI_ICHECK_GT(out_i, (begin_i + end_i)) << "Incorrect crop sizes for (" << i << ")th dim, can not crop more than" << " output size" << out_i << " vs " << (begin_i + end_i); begin_idx.push_back(begin_i); @@ -699,7 +700,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T tvm::tir::make_const(predictions->dtype, 0)); }, name, tag); - ICHECK(T->shape.size() != 0); + TVM_FFI_ICHECK(T->shape.size() != 0); if (reduction == "mean") { auto W = tvm::te::compute( targets->shape, diff --git a/include/tvm/topi/nn/bnn.h b/include/tvm/topi/nn/bnn.h index 2cc494eaa9d4..e474cff16941 100644 --- a/include/tvm/topi/nn/bnn.h +++ b/include/tvm/topi/nn/bnn.h @@ -52,7 +52,7 @@ inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis, std::string name = "PackedInput", std::string tag = "binarize_pack") { auto ishape = data->shape; - ICHECK_EQ(GetConstInt(ishape[axis]) % 32, 0) + TVM_FFI_ICHECK_EQ(GetConstInt(ishape[axis]) % 32, 0) << "binarize_pack: axis size must be a multiple of 32"; arith::Analyzer analyzer; @@ -99,10 +99,10 @@ inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis, * \return Tensor with shape [batch, out_dim], dtype is float32 */ inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight) { - ICHECK_EQ(data->shape.size(), 2) << "binary_dense requires 2-D data"; - ICHECK_EQ(weight->shape.size(), 2) << "binary_dense requires 2-D weight"; - ICHECK_EQ(data->dtype, DataType::UInt(32)) << "binary_dense requires uint32 data"; - ICHECK_EQ(weight->dtype, DataType::UInt(32)) << "binary_dense requires uint32 weight"; + TVM_FFI_ICHECK_EQ(data->shape.size(), 2) << "binary_dense requires 2-D data"; + TVM_FFI_ICHECK_EQ(weight->shape.size(), 2) << "binary_dense requires 2-D weight"; + TVM_FFI_ICHECK_EQ(data->dtype, DataType::UInt(32)) << "binary_dense requires uint32 data"; + TVM_FFI_ICHECK_EQ(weight->dtype, DataType::UInt(32)) << "binary_dense requires uint32 weight"; auto batch = data->shape[0]; auto in_dim = data->shape[1]; diff --git a/include/tvm/topi/nn/dense.h b/include/tvm/topi/nn/dense.h index 113002dc2d88..be0030cd40d5 100644 --- a/include/tvm/topi/nn/dense.h +++ b/include/tvm/topi/nn/dense.h @@ -47,10 +47,10 @@ using namespace tvm::te; */ inline tvm::te::Tensor dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight, const tvm::te::Tensor& bias, const DataType& out_dtype) { - ICHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data"; - ICHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight"; + TVM_FFI_ICHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data"; + TVM_FFI_ICHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight"; if (bias.defined()) { - ICHECK_EQ(bias->shape.size(), 1) << "dense requires 1-D bias"; + TVM_FFI_ICHECK_EQ(bias->shape.size(), 1) << "dense requires 1-D bias"; } auto batch = data->shape[0]; diff --git a/include/tvm/topi/nn/dilate.h b/include/tvm/topi/nn/dilate.h index 816d489c400e..52ef33c80249 100644 --- a/include/tvm/topi/nn/dilate.h +++ b/include/tvm/topi/nn/dilate.h @@ -45,7 +45,7 @@ using namespace tvm::te; * \return The logical conjunction expression */ PrimExpr all(ffi::Array args) { - ICHECK_GT(args.size(), 0) << "all requires at least one argument"; + TVM_FFI_ICHECK_GT(args.size(), 0) << "all requires at least one argument"; PrimExpr ret = args[0]; for (size_t i = 1; i < args.size(); ++i) { @@ -70,8 +70,8 @@ PrimExpr all(ffi::Array args) { inline Tensor dilate(const Tensor& x, ffi::Array strides, double dilation_value, std::string name = "tensor", std::string tag = kInjective) { auto n = x->shape.size(); - ICHECK_EQ(n, strides.size()) << "strides size (" << strides.size() - << ") must match dimension of x (" << n << ")"; + TVM_FFI_ICHECK_EQ(n, strides.size()) + << "strides size (" << strides.size() << ") must match dimension of x (" << n << ")"; ffi::Array out_shape; arith::Analyzer analyzer; diff --git a/include/tvm/topi/nn/group_norm.h b/include/tvm/topi/nn/group_norm.h index 9c03b682407d..b0e71c7cf777 100644 --- a/include/tvm/topi/nn/group_norm.h +++ b/include/tvm/topi/nn/group_norm.h @@ -43,9 +43,9 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& const auto& data_type = data->dtype; const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type; const auto& beta_type = beta.defined() ? beta->dtype : data_type; - ICHECK(data_type == gamma_type && data_type == beta_type) + TVM_FFI_ICHECK(data_type == gamma_type && data_type == beta_type) << "group_norm: data, gamma and beta must have the same type"; - ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) + TVM_FFI_ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) << "group_norm: only support float32 and float16 for now"; bool is_float16 = data_type == DataType::Float(16); // reshape data C -> G, C/G @@ -88,7 +88,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& } else if (new_axis > channel_axis) { new_axes.push_back(new_axis + 1); } else { - ICHECK(false) << "axes can not contain channel axis"; + TVM_FFI_ICHECK(false) << "axes can not contain channel axis"; } } std::sort(new_axes.begin(), new_axes.end()); diff --git a/include/tvm/topi/nn/instance_norm.h b/include/tvm/topi/nn/instance_norm.h index c6a10ec89f0a..66baf3e2f5c1 100644 --- a/include/tvm/topi/nn/instance_norm.h +++ b/include/tvm/topi/nn/instance_norm.h @@ -56,14 +56,14 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso const auto& data_type = data->dtype; const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type; const auto& beta_type = beta.defined() ? beta->dtype : data_type; - ICHECK(data_type == gamma_type && data_type == beta_type) + TVM_FFI_ICHECK(data_type == gamma_type && data_type == beta_type) << "instance_norm: data, gamma and beta must have the same type"; - ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) + TVM_FFI_ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) << "instance_norm: only support float32 and float16 for now"; bool is_float16 = data_type == DataType::Float(16); // sum x and x^2 auto ndim = data->shape.size(); - ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; + TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast(ndim), axis); auto reduce_axes = MakeReduceAxes(real_axis, data); auto target_shape = diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h index 7caa30b0a23b..6c3409aca3a9 100644 --- a/include/tvm/topi/nn/layer_norm.h +++ b/include/tvm/topi/nn/layer_norm.h @@ -54,14 +54,14 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& const auto& data_type = data->dtype; const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type; const auto& beta_type = beta.defined() ? beta->dtype : data_type; - ICHECK(data_type == gamma_type && data_type == beta_type) + TVM_FFI_ICHECK(data_type == gamma_type && data_type == beta_type) << "layer_norm: data, gamma and beta must have the same type"; - ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) + TVM_FFI_ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16)) << "layer_norm: only support float32 and float16 for now"; bool is_float16 = data_type == DataType::Float(16); // sum x and x^2 auto ndim = data->shape.size(); - ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; + TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast(ndim), axis); auto reduce_axes = MakeReduceAxes(real_axis, data); auto target_shape = diff --git a/include/tvm/topi/nn/local_response_norm.h b/include/tvm/topi/nn/local_response_norm.h index 119ab0c19eb0..0c045a1631bc 100644 --- a/include/tvm/topi/nn/local_response_norm.h +++ b/include/tvm/topi/nn/local_response_norm.h @@ -52,10 +52,10 @@ using namespace tvm::te; inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.0001, float beta = 0.75, float bias = 2, std::string name = "tensor", std::string tag = kBroadcast) { - ICHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input"; - ICHECK_EQ(size % 2, 1) << "size should be odd number"; - ICHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC"; - ICHECK(data->dtype.is_float()) << "datatype should be float"; + TVM_FFI_ICHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input"; + TVM_FFI_ICHECK_EQ(size % 2, 1) << "size should be odd number"; + TVM_FFI_ICHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC"; + TVM_FFI_ICHECK(data->dtype.is_float()) << "datatype should be float"; auto input_shape = data->shape; ffi::Array pad_before{0, 0, 0, 0}; ffi::Array pad_after{0, 0, 0, 0}; diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h index 3caf7bf1f7d2..970d74b1c612 100644 --- a/include/tvm/topi/nn/pooling.h +++ b/include/tvm/topi/nn/pooling.h @@ -52,11 +52,11 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, const ffi::Array& padding_size, PoolType pool_type, bool ceil_mode, const size_t height_axis, const size_t width_axis, bool count_include_pad) { - ICHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)"; - ICHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)"; - ICHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements"; - ICHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements"; - ICHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements"; + TVM_FFI_ICHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)"; + TVM_FFI_ICHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)"; + TVM_FFI_ICHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements"; + TVM_FFI_ICHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements"; + TVM_FFI_ICHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements"; auto kernel_height = kernel_size[0]; auto kernel_width = kernel_size[1]; @@ -298,7 +298,8 @@ inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, bool ceil_mode, const std::string& layout = "NCHW", bool count_include_pad = true) { int height_axis = -1, width_axis = -1; - ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; + TVM_FFI_ICHECK(find_height_width(layout, &height_axis, &width_axis)) + << "Unsupported layout " << layout; return pool_grad_impl(out_grad, x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, height_axis, width_axis, count_include_pad); } @@ -325,7 +326,7 @@ inline PrimExpr end_index(const Var& out_index, const PrimExpr& odim, const Prim inline Tensor adaptive_pool_impl(const Tensor& x, const ffi::Array& output_size, PoolType pool_type, const std::vector& axes) { const auto n_dim = output_size.size(); - ICHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension"; + TVM_FFI_ICHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension"; ffi::Array data_shape = x->shape; ffi::Array out_shape = data_shape; @@ -427,7 +428,8 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const ffi::Array& ou inline Tensor adaptive_pool(const Tensor& x, const ffi::Array& output_size, PoolType pool_type, const std::string& layout = "NCHW") { int height_axis = -1, width_axis = -1; - ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; + TVM_FFI_ICHECK(find_height_width(layout, &height_axis, &width_axis)) + << "Unsupported layout " << layout; return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis}); } @@ -442,7 +444,7 @@ inline Tensor adaptive_pool(const Tensor& x, const ffi::Array& output_ inline Tensor adaptive_pool3d(const Tensor& x, const ffi::Array& output_size, PoolType pool_type, const std::string& layout = "NCDHW") { int depth_axis = -1, height_axis = -1, width_axis = -1; - ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis)) + TVM_FFI_ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis)) << "Unsupported layout " << layout; return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis}); } @@ -458,7 +460,7 @@ inline Tensor adaptive_pool3d(const Tensor& x, const ffi::Array& outpu inline Tensor adaptive_pool1d(const Tensor& x, const ffi::Array& output_size, PoolType pool_type, const std::string& layout = "NCW") { int width_axis = -1; - ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout; + TVM_FFI_ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout; return adaptive_pool_impl(x, output_size, pool_type, {width_axis}); } @@ -514,10 +516,12 @@ inline Tensor pool_impl_nd(const Tensor& x, const ffi::Array& kernel_s bool ceil_mode, const std::vector& axis, bool count_include_pad) { int k_size = kernel_size.size(); int x_size = x->shape.size(); - ICHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have same elements as kernel"; - ICHECK_EQ(padding_size.size(), k_size * 2) << "Pooling padding_size must has double elements of" - " kernel"; - ICHECK_EQ(axis.size(), k_size) << "axis must have same elements as kernel"; + TVM_FFI_ICHECK_EQ(stride_size.size(), k_size) + << "Pooling stride_size must have same elements as kernel"; + TVM_FFI_ICHECK_EQ(padding_size.size(), k_size * 2) + << "Pooling padding_size must has double elements of" + " kernel"; + TVM_FFI_ICHECK_EQ(axis.size(), k_size) << "axis must have same elements as kernel"; ffi::Array daxis; std::vector kernel(k_size); @@ -707,7 +711,7 @@ inline Tensor pool1d(const Tensor& x, const ffi::Array& kernel_size, const ffi::Array& padding_size, PoolType pool_type, bool ceil_mode, const std::string& layout = "NCW", bool count_include_pad = true) { int width_axis = -1; - ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout; + TVM_FFI_ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout; std::vector axis = {width_axis}; return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type, ceil_mode, axis, count_include_pad); @@ -749,7 +753,8 @@ inline Tensor pool2d(const Tensor& x, const ffi::Array& kernel_size, const ffi::Array& padding_size, PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW", bool count_include_pad = true) { int height_axis = -1, width_axis = -1; - ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; + TVM_FFI_ICHECK(find_height_width(layout, &height_axis, &width_axis)) + << "Unsupported layout " << layout; std::vector axis = {height_axis, width_axis}; return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type, ceil_mode, axis, count_include_pad); @@ -792,7 +797,7 @@ inline Tensor pool3d(const Tensor& x, const ffi::Array& kernel_size, const ffi::Array& padding_size, PoolType pool_type, bool ceil_mode, const std::string& layout = "NCDHW", bool count_include_pad = true) { int depth_axis = -1, height_axis = -1, width_axis = -1; - ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis)) + TVM_FFI_ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis)) << "Unsupported layout " << layout; std::vector axis = {depth_axis, height_axis, width_axis}; return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type, diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h index 66a2ae62dfec..4f6292d968ac 100644 --- a/include/tvm/topi/nn/rms_norm.h +++ b/include/tvm/topi/nn/rms_norm.h @@ -52,7 +52,7 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const ffi::Arra std::string tag = kInjective) { const auto& data_type = data->dtype; const auto& weight_type = weight.defined() ? weight->dtype : data_type; - ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type"; + TVM_FFI_ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type"; const auto& data_fp32 = cast(data, DataType::Float(32)); const auto& weight_fp32 = cast(weight, DataType::Float(32)); @@ -61,7 +61,7 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const ffi::Arra auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true); auto ndim = data_fp32->shape.size(); - ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; + TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast(ndim), axis); auto reduce_extent = make_const(data_fp32->dtype, 1); for (int i : real_axis) { diff --git a/include/tvm/topi/nn/softmax.h b/include/tvm/topi/nn/softmax.h index f58d66ece139..8b18ebe4b686 100644 --- a/include/tvm/topi/nn/softmax.h +++ b/include/tvm/topi/nn/softmax.h @@ -54,7 +54,7 @@ inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor if (axis < 0) { axis = ndim + axis; } - ICHECK_LT(axis, ndim) << "axis parameter should be less than input dim"; + TVM_FFI_ICHECK_LT(axis, ndim) << "axis parameter should be less than input dim"; auto k1 = tvm::te::reduce_axis(Range(0, input_shape[axis]), "k1"); auto k2 = tvm::te::reduce_axis(Range(0, input_shape[axis]), "k2"); @@ -126,7 +126,7 @@ inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor */ inline Tensor log_softmax(const Tensor& x, std::string name = "tensor", std::string tag = "log_softmax_output") { - ICHECK_EQ(x->shape.size(), 2) << "Log softmax requires 2-D input"; + TVM_FFI_ICHECK_EQ(x->shape.size(), 2) << "Log softmax requires 2-D input"; PrimExpr m = x->shape[0]; PrimExpr n = x->shape[1]; diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index fda754061bbe..5345cc8e0ea9 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -75,8 +75,8 @@ inline std::vector GetRealAxis(int ndim, const ffi::Optional(val)); } std::sort(real_axis.begin(), real_axis.end()); @@ -184,7 +184,7 @@ inline Tensor DoCommReduce(const Tensor& data, FReduce func, inline Tensor CommReduce(const Tensor& data, const ffi::Optional>& axis, FReduce func, bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); - ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; + TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast(ndim), axis); auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d); return DoCommReduce(data, func, target_shape, real_axis, @@ -207,7 +207,7 @@ inline Tensor CommReduce(const Tensor& data, const ffi::Optional>& axis, FCommReduce func, bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); - ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; + TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast(ndim), axis); auto reduce_axes = MakeReduceAxes(real_axis, data); auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d); @@ -340,7 +340,7 @@ inline Tensor collapse_sum(const Tensor& data, ffi::Array target_shape int isize = data->shape.size(); int osize = target_shape.size(); - ICHECK_GE(isize, osize) + TVM_FFI_ICHECK_GE(isize, osize) << "Invalid collapse: input dimensionality smaller than output dimensionality.\ninput shape: " << data->shape << "\nvs\noutput shape: " << target_shape; @@ -591,7 +591,7 @@ inline Tensor prod(const Tensor& data, const ffi::Optional>& inline FCommReduce MakeTupleSumReducer() { auto fcombine = [](ffi::Array lhs, ffi::Array rhs) { ffi::Array result; - ICHECK_EQ(lhs.size(), rhs.size()); + TVM_FFI_ICHECK_EQ(lhs.size(), rhs.size()); result.reserve(lhs.size()); for (size_t i = 0; i < lhs.size(); ++i) { result.push_back(lhs[i] + rhs[i]); diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 6f395575cefa..24a1521c1e96 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -76,13 +76,14 @@ using namespace topi::detail; inline Tensor sliding_window(const Tensor& x, int axis, ffi::Array window_shape, ffi::Array strides, std::string name = "T_sliding_window", std::string tag = "") { - CHECK_GE(axis, 0); + TVM_FFI_ICHECK_GE(axis, 0); auto _axis = size_t(axis); - CHECK_LT(_axis, x->shape.size()) << "axis must be a valid dimension index of x."; - CHECK_EQ(x->shape.size() - _axis, window_shape.size()) + TVM_FFI_ICHECK_LT(_axis, x->shape.size()) << "axis must be a valid dimension index of x."; + TVM_FFI_ICHECK_EQ(x->shape.size() - _axis, window_shape.size()) << "There must be a window shape for every dimension of x " << "over which we are sliding the window."; - CHECK_EQ(strides.size(), window_shape.size()) << "Windows and strides should be the same length."; + TVM_FFI_ICHECK_EQ(strides.size(), window_shape.size()) + << "Windows and strides should be the same length."; // Compute the new shape. ffi::Array new_shape; @@ -109,7 +110,7 @@ inline Tensor sliding_window(const Tensor& x, int axis, ffi::Array wind new_shape.push_back(window_shape[i]); } - ICHECK(new_shape.size() == _axis + 2 * window_shape.size()); + TVM_FFI_ICHECK(new_shape.size() == _axis + 2 * window_shape.size()); return compute( new_shape, @@ -133,7 +134,7 @@ inline Tensor sliding_window(const Tensor& x, int axis, ffi::Array wind idx.push_back(window_idx * stride + idx_within_window); } - ICHECK(idx.size() == x->shape.size()); + TVM_FFI_ICHECK(idx.size() == x->shape.size()); return x(idx); }, @@ -155,11 +156,11 @@ inline Tensor sliding_window(const Tensor& x, int axis, ffi::Array wind inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1, std::string name = "T_expand_dims", std::string tag = kBroadcast) { int ndim = static_cast(x->shape.size()); - ICHECK(-ndim - 1 <= axis && axis <= ndim) + TVM_FFI_ICHECK(-ndim - 1 <= axis && axis <= ndim) << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" << ", but got axis = " << axis << ", and data.ndim = " << ndim; - ICHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`" - << ", but got num_newaxis = " << num_newaxis; + TVM_FFI_ICHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`" + << ", but got num_newaxis = " << num_newaxis; if (axis < 0) { // Calculate offset from last dimension axis = ndim + axis + 1; @@ -218,13 +219,14 @@ inline Tensor transpose(const Tensor& x, ffi::Optional> opt_ new_axis = static_cast(x->shape.size()) + axis; axes.Set(i, new_axis); } - ICHECK((new_axis >= 0) && (new_axis < static_cast(x->shape.size()))) + TVM_FFI_ICHECK((new_axis >= 0) && (new_axis < static_cast(x->shape.size()))) << "axis=" << axis << " is invalid for the " << static_cast(x->shape.size()) << "-dimensional input tensor"; for (size_t j = 0; j < axes.size(); ++j) { if (i != j) { - ICHECK(new_axis != static_cast(axes[j]->value)) << "repeated axis in transpose"; + TVM_FFI_ICHECK(new_axis != static_cast(axes[j]->value)) + << "repeated axis in transpose"; } } new_shape.push_back(x->shape[new_axis]); @@ -273,14 +275,14 @@ inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int s batch_axis = static_cast(x->shape.size()) + batch_axis; } - ICHECK(seq_lengths_dim == 1) << "seq_lengths should be 1D vector"; + TVM_FFI_ICHECK(seq_lengths_dim == 1) << "seq_lengths should be 1D vector"; - ICHECK(GetConstInt(seq_lengths->shape[0]) == GetConstInt(x->shape[batch_axis])) + TVM_FFI_ICHECK(GetConstInt(seq_lengths->shape[0]) == GetConstInt(x->shape[batch_axis])) << "For reverse_sequnece seq_lengths size should match with dimension of batch axis" << ", but got dimension of batch_axis = " << GetConstInt(x->shape[batch_axis]) << ", and seq_length size = " << GetConstInt(seq_lengths->shape[0]); - ICHECK((0 <= batch_axis) && (batch_axis < static_cast(x->shape.size()))) + TVM_FFI_ICHECK((0 <= batch_axis) && (batch_axis < static_cast(x->shape.size()))) << "batch_axis=" << batch_axis_inp << " is invalid for the " << static_cast(x->shape.size()) << "-dimensional input tensor"; } @@ -288,7 +290,7 @@ inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int s if (seq_axis < 0) { seq_axis = static_cast(x->shape.size()) + seq_axis; } - ICHECK((0 <= seq_axis) && (seq_axis < static_cast(x->shape.size()))) + TVM_FFI_ICHECK((0 <= seq_axis) && (seq_axis < static_cast(x->shape.size()))) << "seq_axis=" << seq_axis_inp << " is invalid for the " << static_cast(x->shape.size()) << "-dimensional input tensor"; @@ -479,12 +481,13 @@ inline Tensor squeeze(const Tensor& x, ffi::Optional> opt_ax inline Tensor concatenate(const ffi::Array& inputs, int axis = 0, std::string name = "T_concat", std::string tag = kInjective) { int ndim = static_cast(inputs[0]->shape.size()); - ICHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)" - << ", but got axis = " << axis << ", and ndim = " << ndim; + TVM_FFI_ICHECK(-ndim <= axis && axis < ndim) + << "concatenate only accepts `axis` in [-ndim, ndim)" + << ", but got axis = " << axis << ", and ndim = " << ndim; if (axis < 0) { axis += ndim; } - ICHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds"; + TVM_FFI_ICHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds"; ffi::Array axis_sizes; for (auto t : inputs) { @@ -538,13 +541,13 @@ inline Tensor concatenate(const ffi::Array& inputs, int axis = 0, inline Tensor stack(const ffi::Array& inputs, int axis = 0, std::string name = "T_stack", std::string tag = kInjective) { int ndim = static_cast(inputs[0]->shape.size()); - ICHECK(-ndim - 1 <= axis && axis <= ndim) + TVM_FFI_ICHECK(-ndim - 1 <= axis && axis <= ndim) << "stack only accepts `axis` in [-ndim, ndim)" << ", but got axis = " << axis << ", and ndim = " << ndim; if (axis < 0) { axis += ndim + 1; } - ICHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds"; + TVM_FFI_ICHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds"; const int stack_size = static_cast(inputs.size()); ffi::Array out_shape; @@ -587,7 +590,7 @@ inline ffi::Array split_indices_array(const Tensor& x, ffi::Array(x->shape.size()); } - ICHECK_LT(axis, x->shape.size()) << "axis out of bounds"; + TVM_FFI_ICHECK_LT(axis, x->shape.size()) << "axis out of bounds"; auto src_axis_size = x->shape[axis]; std::vector begin_ids; @@ -597,7 +600,7 @@ inline ffi::Array split_indices_array(const Tensor& x, ffi::Array(); auto back_node = begin_ids.back().as(); if (idx_node && back_node) { - ICHECK_GT(idx_node->value, back_node->value) << "split_indices must be sorted"; + TVM_FFI_ICHECK_GT(idx_node->value, back_node->value) << "split_indices must be sorted"; } begin_ids.push_back(idx); } @@ -716,14 +719,14 @@ inline te::Tensor dynamic_strided_slice_with_axes( bool assume_inbound = true, std::string name = "T_dynamic_strided_slice_with_axes", std::string tag = kInjective) { const size_t src_tensor_dim = x->shape.size(); - ICHECK_EQ(begin.size(), end.size()); - ICHECK_EQ(begin.size(), strides.size()); - ICHECK_EQ(begin.size(), axes.size()); - ICHECK_LE(begin.size(), src_tensor_dim); + TVM_FFI_ICHECK_EQ(begin.size(), end.size()); + TVM_FFI_ICHECK_EQ(begin.size(), strides.size()); + TVM_FFI_ICHECK_EQ(begin.size(), axes.size()); + TVM_FFI_ICHECK_LE(begin.size(), src_tensor_dim); for (const auto& axis_imm : axes) { int axis = axis_imm->value; - ICHECK_LT(axis, src_tensor_dim); + TVM_FFI_ICHECK_LT(axis, src_tensor_dim); } arith::Analyzer analyzer; @@ -773,11 +776,11 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const ffi::Array& std::string name = "T_dynamic_strided_slice", std::string tag = kInjective) { const size_t src_tensor_dim = x->shape.size(); - ICHECK_LE(begin.size(), src_tensor_dim); - ICHECK_LE(end.size(), src_tensor_dim); - ICHECK_LE(strides.size(), src_tensor_dim); - ICHECK_EQ(begin.size(), end.size()); - ICHECK_EQ(begin.size(), strides.size()); + TVM_FFI_ICHECK_LE(begin.size(), src_tensor_dim); + TVM_FFI_ICHECK_LE(end.size(), src_tensor_dim); + TVM_FFI_ICHECK_LE(strides.size(), src_tensor_dim); + TVM_FFI_ICHECK_EQ(begin.size(), end.size()); + TVM_FFI_ICHECK_EQ(begin.size(), strides.size()); const size_t num_slice_axes = begin.size(); ffi::Array out_shape; @@ -835,8 +838,8 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b std::string tag = topi::kInjective) { DataType index_dtype = begin->shape[0]->dtype; const int64_t num_dynamic_axes = begin->shape[0].as()->value; - ICHECK_EQ(end->shape[0].as()->value, num_dynamic_axes); - ICHECK_EQ(strides->shape[0].as()->value, num_dynamic_axes); + TVM_FFI_ICHECK_EQ(end->shape[0].as()->value, num_dynamic_axes); + TVM_FFI_ICHECK_EQ(strides->shape[0].as()->value, num_dynamic_axes); ffi::Array begin_expr, end_expr, strides_expr; for (int64_t i = 0; i < num_dynamic_axes; ++i) { @@ -868,7 +871,8 @@ inline ffi::Array StridedSliceOutputShape(const ffi::Array& const ffi::Array& strides, const ffi::Array& axes, const std::string& slice_mode) { - ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()); + TVM_FFI_ICHECK(axes.size() == begin.size() && axes.size() == end.size() && + axes.size() == strides.size()); std::vector begin_vec, end_vec, strides_vec; std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode); auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes, @@ -901,8 +905,9 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const ffi::Array std::string name = "T_strided_slice_with_axes", std::string tag = kInjective) { const size_t src_tensor_dim = x->shape.size(); - ICHECK(axes.size() <= src_tensor_dim); - ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()); + TVM_FFI_ICHECK(axes.size() <= src_tensor_dim); + TVM_FFI_ICHECK(axes.size() == begin.size() && axes.size() == end.size() && + axes.size() == strides.size()); std::vector begin_vec, end_vec, strides_vec; std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode); @@ -989,11 +994,11 @@ inline ffi::Array split_n_sections(const Tensor& x, int num_sections, in if (axis < 0) { axis += static_cast(x->shape.size()); } - ICHECK_LT(axis, x->shape.size()) << "axis out of bounds"; + TVM_FFI_ICHECK_LT(axis, x->shape.size()) << "axis out of bounds"; auto src_axis_size = x->shape[axis]; - ICHECK_GT(num_sections, 0) << "Slice count must be > 0"; + TVM_FFI_ICHECK_GT(num_sections, 0) << "Slice count must be > 0"; ffi::Array split_indices; auto seg_size = indexdiv(src_axis_size + num_sections - 1, num_sections); @@ -1082,8 +1087,9 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, double mask_value, int axis, std::string name = "T_sequence_mask", std::string tag = kInjective) { - ICHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1"; - ICHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,)."; + TVM_FFI_ICHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1"; + TVM_FFI_ICHECK_EQ(valid_length->shape.size(), 1) + << "valid_length must have ndim=1, i.e., (batch_size,)."; auto length_dim = data->shape[axis]; auto batch_dim = data->shape[1 - axis]; ffi::Array out_shape = data->shape; @@ -1123,8 +1129,8 @@ inline Tensor take(const Tensor& a, ffi::Variant indices, int if (axis < 0) { axis += static_cast(a->shape.size()); } - ICHECK_GE(axis, 0) << "axis out of bounds"; - ICHECK_LT(axis, a->shape.size()) << "axis out of bounds"; + TVM_FFI_ICHECK_GE(axis, 0) << "axis out of bounds"; + TVM_FFI_ICHECK_LT(axis, a->shape.size()) << "axis out of bounds"; auto axis_dim = a->shape[axis]; auto indices_shape = [&]() -> ffi::Array { if (auto tensor = indices.as()) { @@ -1138,21 +1144,22 @@ inline Tensor take(const Tensor& a, ffi::Variant indices, int int batch_dims_ = batch_dims; if (batch_dims_ != 0) { - ICHECK_GE(batch_dims_, -indices_len) << "batch_dims out of bounds"; - ICHECK_LE(batch_dims_, indices_len) << "batch_dims out of bounds"; + TVM_FFI_ICHECK_GE(batch_dims_, -indices_len) << "batch_dims out of bounds"; + TVM_FFI_ICHECK_LE(batch_dims_, indices_len) << "batch_dims out of bounds"; if (batch_dims_ < 0) { batch_dims_ = indices_len + batch_dims_; } - ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds"; - ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis"; + TVM_FFI_ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds"; + TVM_FFI_ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis"; for (int i = 0; i < batch_dims_; ++i) { auto addr1 = a->shape[i]; auto addr2 = indices_shape[i]; auto v1 = static_cast(&addr1)->get()->value; auto v2 = static_cast(&addr2)->get()->value; - ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i << "]"; + TVM_FFI_ICHECK_EQ(v1, v2) << "a.shape[" << i << "] should be equal to indices.shape[" << i + << "]"; } } @@ -1177,10 +1184,10 @@ inline Tensor take(const Tensor& a, ffi::Variant indices, int if (auto tensor = indices.as()) { return tensor.value()(indices_position); } else if (auto prim = indices.as()) { - ICHECK_EQ(indices_position.size(), 0); + TVM_FFI_ICHECK_EQ(indices_position.size(), 0); return prim.value(); } else { - LOG(FATAL) << "Variant did not contain either allowed type"; + TVM_FFI_THROW(InternalError) << "Variant did not contain either allowed type"; } }; @@ -1309,8 +1316,8 @@ inline Tensor take(const Tensor& a, ffi::Variant indices, int */ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, std::string name = "T_where", std::string tag = kBroadcast) { - ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs " - << y->dtype; + TVM_FFI_ICHECK_EQ(x->dtype, y->dtype) + << "x and y must have the same dtype: " << x->dtype << " vs " << y->dtype; auto get_out_shape = [&]() { auto bh1 = detail::BroadcastShape(x->shape, y->shape); ffi::Array common_shape1(bh1.common_shape.begin(), bh1.common_shape.end()); @@ -1350,11 +1357,11 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = "T_repeat", std::string tag = kBroadcast) { int ndim = static_cast(x->shape.size()); - ICHECK(-ndim - 1 <= axis && axis <= ndim) + TVM_FFI_ICHECK(-ndim - 1 <= axis && axis <= ndim) << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" << ", but got axis = " << axis << ", and data.ndim = " << ndim; - ICHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`" - << ", but got repeats = " << repeats; + TVM_FFI_ICHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`" + << ", but got repeats = " << repeats; if (axis < 0) { // Calculate offset from last dimension axis += ndim; @@ -1492,18 +1499,18 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, std::string name = "T_gather", std::string tag = kInjective) { size_t ndim_d = data->shape.size(); size_t ndim_i = indices->shape.size(); - ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar."; - ICHECK_EQ(ndim_d, ndim_i); + TVM_FFI_ICHECK_GE(ndim_d, 1) << "Cannot gather from a scalar."; + TVM_FFI_ICHECK_EQ(ndim_d, ndim_i); if (axis < 0) { axis += ndim_d; } - ICHECK_GE(axis, 0); - ICHECK_LT(axis, ndim_d); + TVM_FFI_ICHECK_GE(axis, 0); + TVM_FFI_ICHECK_LT(axis, ndim_d); if (indices->shape[axis].as()) { size_t indices_dim_i = static_cast(GetConstInt(indices->shape[axis])); - ICHECK_GE(indices_dim_i, 1); + TVM_FFI_ICHECK_GE(indices_dim_i, 1); } - ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()); + TVM_FFI_ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()); ffi::Array out_shape; for (size_t i = 0; i < ndim_i; ++i) { @@ -1545,10 +1552,10 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim std::string name = "T_gather_nd", std::string tag = kInjective) { size_t ndim_d = data->shape.size(); size_t ndim_i = indices->shape.size(); - ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimensions"; + TVM_FFI_ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimensions"; size_t indices_dim0 = static_cast(GetConstInt(indices->shape[0])); - ICHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more " - << "than dimensions of data tensor"; + TVM_FFI_ICHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more " + << "than dimensions of data tensor"; ffi::Array out_shape; for (size_t i = 1; i < ndim_i; ++i) { out_shape.push_back(indices->shape[i]); @@ -1626,8 +1633,8 @@ inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, const tvm::te::Tensor& B */ inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2, std::string name = "T_tensordot", std::string tag = kMatMul) { - ICHECK_GE(A->shape.size(), axes); - ICHECK_GE(B->shape.size(), axes); + TVM_FFI_ICHECK_GE(A->shape.size(), axes); + TVM_FFI_ICHECK_GE(B->shape.size(), axes); ffi::Array output_shape(A->shape.begin(), A->shape.end() + (-axes)); for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it); @@ -1673,7 +1680,7 @@ inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2, inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, ffi::Array A_axes, ffi::Array B_axes, std::string name = "T_tensordot", std::string tag = kMatMul) { - ICHECK_EQ(A_axes.size(), B_axes.size()); + TVM_FFI_ICHECK_EQ(A_axes.size(), B_axes.size()); auto A_axes_val = GetConstIntValues(A_axes, "A_axes"); auto B_axes_val = GetConstIntValues(B_axes, "B_axes"); @@ -1798,11 +1805,11 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, return src; } - ICHECK(src_layout_struct.defined() && dst_layout_struct.defined()) + TVM_FFI_ICHECK(src_layout_struct.defined() && dst_layout_struct.defined()) << "cannot convert from/to undefined layout"; auto layout_converter = tir::BijectiveLayout(src_layout_struct, dst_layout_struct); - ICHECK(layout_converter.defined()) + TVM_FFI_ICHECK(layout_converter.defined()) << "cannot convert from " << src_layout << " to " << dst_layout; ffi::Array dst_shape = layout_converter.ForwardShape(src->shape); @@ -1846,7 +1853,7 @@ inline void parse_auto_scheduler_layout(const ffi::String& layout, ffi::Array src_indices; for (const std::string& src_axis : src_axes) { PrimExpr src_index = 0; - CHECK_EQ(dst_indices_expr.size(), dst_axes.size()); + TVM_FFI_ICHECK_EQ(dst_indices_expr.size(), dst_axes.size()); for (size_t i = 0; i < dst_axes.size(); ++i) { if (dst_axes[i] == src_axis) { src_index = src_index * dst_shape[i] + dst_indices_expr[i]; @@ -2066,10 +2073,11 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const PrimExpr& default_value, const std::string name = "T_sparse_to_dense", const std::string tag = kInjective) { - ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values"; - ICHECK_LE(sparse_indices->shape.size(), 3) + TVM_FFI_ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values"; + TVM_FFI_ICHECK_LE(sparse_indices->shape.size(), 3) << "sparse_indices tensor should be 0D, 1D, or 2D only"; - ICHECK_LE(sparse_values->shape.size(), 2) << "sparse_values tensor should be 0D or 1D only"; + TVM_FFI_ICHECK_LE(sparse_values->shape.size(), 2) + << "sparse_values tensor should be 0D or 1D only"; const auto rank_sparse_indices = static_cast(sparse_indices->shape.size()); ffi::Array oshape; @@ -2172,7 +2180,7 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k inline Tensor adv_index(const Tensor& data, const ffi::Array& indices, const std::string name = "advanced_index", const std::string tag = kInjective) { - ICHECK_LE(indices.size(), data->shape.size()) << "too many indices for data!"; + TVM_FFI_ICHECK_LE(indices.size(), data->shape.size()) << "too many indices for data!"; ffi::Array oshape; ffi::Array broadcast_shape; ffi::Array bindices; @@ -2227,18 +2235,18 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b std::string name = "T_strided_slice_dynamic", std::string tag = kInjective) { const size_t num_dynamic_axes = x.ndim(); - ICHECK_EQ(begin.ndim(), 1); - ICHECK_EQ(end.ndim(), 1); - ICHECK_EQ(strides.ndim(), 1); + TVM_FFI_ICHECK_EQ(begin.ndim(), 1); + TVM_FFI_ICHECK_EQ(end.ndim(), 1); + TVM_FFI_ICHECK_EQ(strides.ndim(), 1); const auto* len_begin = begin->shape[0].as(); const auto* len_end = end->shape[0].as(); const auto* len_strides = strides->shape[0].as(); - ICHECK(len_begin); - ICHECK(len_end); - ICHECK(len_strides); - ICHECK_EQ(len_begin->value, num_dynamic_axes); - ICHECK_EQ(len_end->value, num_dynamic_axes); - ICHECK_EQ(len_strides->value, num_dynamic_axes); + TVM_FFI_ICHECK(len_begin); + TVM_FFI_ICHECK(len_end); + TVM_FFI_ICHECK(len_strides); + TVM_FFI_ICHECK_EQ(len_begin->value, num_dynamic_axes); + TVM_FFI_ICHECK_EQ(len_end->value, num_dynamic_axes); + TVM_FFI_ICHECK_EQ(len_strides->value, num_dynamic_axes); return te::compute( output_shape, diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index f6f0b9f4d8df..cae6b2d7c6b5 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -54,7 +54,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { } void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { - ICHECK(range.defined()); + TVM_FFI_ICHECK(range.defined()); if (tir::is_one(range->extent)) { this->Bind(var, range->min, allow_override); } else { @@ -123,7 +123,7 @@ void Analyzer::Bind(const ffi::Map& variables, bool allow_override) } void ConstraintContext::EnterWithScope() { - ICHECK(recovery_functions_.size() == 0); + TVM_FFI_ICHECK(recovery_functions_.size() == 0); // entering the scope. recovery_functions_.push_back(analyzer_->const_int_bound.EnterConstraint(constraint_)); recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_)); @@ -301,7 +301,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (args.size() == 2) { *ret = self->Simplify(args[0].cast(), args[1].cast()); } else { - LOG(FATAL) << "Invalid size of argument (" << args.size() << ")"; + TVM_FFI_THROW(InternalError) << "Invalid size of argument (" << args.size() << ")"; } }); } else if (name == "rewrite_simplify") { diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index eb9edca36341..e23219465376 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -263,7 +263,7 @@ CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) { case kLess: return kGreater; default: - LOG(FATAL) << "Not a valid compare op"; + TVM_FFI_THROW(InternalError) << "Not a valid compare op"; } } diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index f321d761198c..5192bc1ad179 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -60,7 +60,7 @@ inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) { if (mode == kTruncDiv) { return truncmod(a, b); } else { - ICHECK_EQ(mode, kFloorDiv); + TVM_FFI_ICHECK_EQ(mode, kFloorDiv); return floormod(a, b); } } @@ -69,7 +69,7 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) { if (mode == kTruncDiv) { return truncdiv(a, b); } else { - ICHECK_EQ(mode, kFloorDiv); + TVM_FFI_ICHECK_EQ(mode, kFloorDiv); return floordiv(a, b); } } @@ -120,7 +120,9 @@ class SplitExprNode : public CanonicalExprNode { DivMode div_mode{kTruncDiv}; /*! \brief verify that this is a valid entry. */ - void Verify() const { ICHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0); } + void Verify() const { + TVM_FFI_ICHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0); + } PrimExpr NormalizeWithScale(int64_t sscale) const { PrimExpr res = this->index; @@ -136,7 +138,7 @@ class SplitExprNode : public CanonicalExprNode { } sscale *= this->scale; if (sscale != 1) { - ICHECK(!dtype.is_uint() || sscale > 0); + TVM_FFI_ICHECK(!dtype.is_uint() || sscale > 0); res = res * make_const(dtype, sscale); } return res; @@ -180,7 +182,7 @@ class SplitExprNode : public CanonicalExprNode { } } if (this->scale != 1) { - ICHECK(!this->dtype.is_uint() || this->scale > 0); + TVM_FFI_ICHECK(!this->dtype.is_uint() || this->scale > 0); res = res * make_const(this->dtype, this->scale); if (!CastIsSafe(dtype, res, analyzer)) { return false; @@ -278,10 +280,10 @@ class SumExprNode : public CanonicalExprNode { * \param scale The scale to be applied. */ void DivideBy(int64_t scale) { - ICHECK_EQ(this->base % scale, 0); + TVM_FFI_ICHECK_EQ(this->base % scale, 0); this->base /= scale; for (size_t i = 0; i < this->args.size(); ++i) { - ICHECK_EQ(args[i]->scale % scale, 0); + TVM_FFI_ICHECK_EQ(args[i]->scale % scale, 0); args[i].CopyOnWrite()->scale /= scale; } } @@ -700,7 +702,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { return expr; } expr = ToSplitExpr(Normalize(expr)); - ICHECK(expr->DivModeCompatibleTo(div_mode)); + TVM_FFI_ICHECK(expr->DivModeCompatibleTo(div_mode)); expr.CopyOnWrite()->div_mode = div_mode; return expr; } @@ -843,7 +845,7 @@ void CanonicalSimplifier::Impl::SeparateDivisibleParts(const SumExprNode* psum, } SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { - ICHECK_GT(cval, 0); + TVM_FFI_ICHECK_GT(cval, 0); lhs = ConvertDivMode(lhs, div_mode); // the following rule works for both floordiv and truncdiv @@ -877,8 +879,8 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval, } // directly return the split with cval == 1 lhs = ToSplitExpr(Normalize(lhs)); - ICHECK(lhs->DivModeCompatibleTo(div_mode)); - ICHECK_EQ(lhs->scale, 1); + TVM_FFI_ICHECK(lhs->DivModeCompatibleTo(div_mode)); + TVM_FFI_ICHECK_EQ(lhs->scale, 1); lhs.CopyOnWrite()->lower_factor *= cval; lhs.CopyOnWrite()->div_mode = div_mode; return lhs; @@ -1069,7 +1071,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { } SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { - ICHECK_GT(cval, 0); + TVM_FFI_ICHECK_GT(cval, 0); lhs = ConvertDivMode(lhs, div_mode); if (lhs->scale % cval == 0) { @@ -1114,9 +1116,9 @@ SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval, } // Normalize the value. lhs = ToSplitExpr(Normalize(lhs)); - ICHECK(lhs->DivModeCompatibleTo(div_mode)); - ICHECK_EQ(lhs->scale, 1); - ICHECK_EQ(lhs->lower_factor, 1); + TVM_FFI_ICHECK(lhs->DivModeCompatibleTo(div_mode)); + TVM_FFI_ICHECK_EQ(lhs->scale, 1); + TVM_FFI_ICHECK_EQ(lhs->lower_factor, 1); lhs.CopyOnWrite()->div_mode = div_mode; lhs.CopyOnWrite()->upper_factor = cval; return lhs; @@ -1157,7 +1159,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { // continue to use logic below. a = extra; psum = a.as(); - ICHECK(psum != nullptr); + TVM_FFI_ICHECK(psum != nullptr); } } } @@ -1226,7 +1228,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { // continue to use logic below. a = extra; psum = a.as(); - ICHECK(psum != nullptr); + TVM_FFI_ICHECK(psum != nullptr); } } // Simplify the offset constant if necessary. @@ -1409,7 +1411,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) { SumExpr divisible, extra; SeparateDivisibleParts(lhs, gcd, &divisible, &extra); DataType dtype = divisible->dtype; - ICHECK(extra->dtype == dtype); + TVM_FFI_ICHECK(extra->dtype == dtype); PrimExpr normal_extra = extra->Normalize(); if (this->analyzer_->CanProve(normal_extra < make_const(dtype, gcd)) && this->analyzer_->CanProve(normal_extra >= make_const(dtype, 0))) { diff --git a/src/arith/conjunctive_normal_form.cc b/src/arith/conjunctive_normal_form.cc index 1c5f31a913a1..7a87a1fbabf5 100644 --- a/src/arith/conjunctive_normal_form.cc +++ b/src/arith/conjunctive_normal_form.cc @@ -228,7 +228,7 @@ AndOfOrs::Key AndOfOrs::GetKey(const PrimExpr& expr) { PrimExpr AndOfOrs::GetExpr(AndOfOrs::Key key) const { auto it = key_to_expr_.find(key); - ICHECK(it != key_to_expr_.end()); + TVM_FFI_ICHECK(it != key_to_expr_.end()); return it->second; } diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 5118204db69c..4128e43e6e25 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -154,8 +154,8 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - ICHECK(!((pa && pa->dtype.is_uint() && pa->value == 0U) && - (pb && pb->dtype.is_uint() && pb->value > 0U))) + TVM_FFI_ICHECK(!((pa && pa->dtype.is_uint() && pa->value == 0U) && + (pb && pb->dtype.is_uint() && pb->value > 0U))) << "Checked failed. Minuend 's value is 0U and it's dtype is uint " << "while Subtrahend's dtype is uint; which will cause a negative uint"; const DataType& rtype = a.dtype(); @@ -220,7 +220,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { if (pa && pb) { // due to division and mod can have different modes // NOTE: this will assumes truc div. - ICHECK_NE(pb->value, 0) << "Divide by zero"; + TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; int64_t res = pa->value / pb->value; return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); } @@ -229,10 +229,10 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } if (pb) { if (pb->value == 1) return a; - ICHECK_NE(pb->value, 0) << "Divide by zero"; + TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; } if (fa && fb) { - ICHECK_NE(fb->value, 0) << "Divide by zero"; + TVM_FFI_ICHECK_NE(fb->value, 0) << "Divide by zero"; if (rtype.bits() == 32) { return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast(fa->value) / static_cast(fb->value))); @@ -243,7 +243,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { if (fa && fa->value == 0) return a; if (fb) { if (fb->value == 1) return a; - ICHECK_NE(fb->value, 0) << "Divide by zero"; + TVM_FFI_ICHECK_NE(fb->value, 0) << "Divide by zero"; } }); return std::nullopt; @@ -254,7 +254,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { - ICHECK_NE(pb->value, 0) << "Divide by zero"; + TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; int64_t res = pa->value % pb->value; return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); } @@ -263,7 +263,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } if (pb) { if (pb->value == 1) return tir::make_zero(rtype); - ICHECK_NE(pb->value, 0) << "Divide by zero"; + TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; } }); return std::nullopt; @@ -274,7 +274,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { - ICHECK_NE(pb->value, 0) << "Divide by zero"; + TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; int64_t res = arith::floordiv(pa->value, pb->value); return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); } @@ -283,7 +283,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr } if (pb) { if (pb->value == 1) return a; - ICHECK_NE(pb->value, 0) << "Divide by zero"; + TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; } if (fa && fb && fb->value != 0) { if (rtype.bits() == 32) { @@ -298,7 +298,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr if (fa && fa->value == 0) return a; if (fb) { if (fb->value == 1) return a; - ICHECK_NE(fb->value, 0) << "Divide by zero"; + TVM_FFI_ICHECK_NE(fb->value, 0) << "Divide by zero"; } }); return std::nullopt; @@ -309,7 +309,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { - ICHECK_NE(pb->value, 0) << "Divide by zero"; + TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; int64_t res = arith::floormod(pa->value, pb->value); return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); } @@ -318,7 +318,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr } if (pb) { if (pb->value == 1) return tir::make_zero(rtype); - ICHECK_NE(pb->value, 0) << "Divide by zero"; + TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; } }); return std::nullopt; diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 6dd029e136ea..78a456784c8d 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -129,7 +129,7 @@ class ConstIntBoundAnalyzer::Impl if (!allow_override) { auto it = var_map_.find(var); if (it != var_map_.end()) { - ICHECK(it->second == info) + TVM_FFI_ICHECK(it->second == info) << "Trying to update var \'" << var << "\'" << " with a different const bound: " << "original=" << ConstIntBound(it->second.min_value, it->second.max_value) @@ -175,7 +175,7 @@ class ConstIntBoundAnalyzer::Impl auto val = bound_->find(expr); if (val != bound_->end()) { auto everything = Everything(expr->dtype); - ICHECK( + TVM_FFI_ICHECK( (val->second->min_value == res.min_value && val->second->max_value == res.max_value) || (val->second->min_value == everything.min_value && val->second->max_value == everything.max_value)) @@ -226,7 +226,7 @@ class ConstIntBoundAnalyzer::Impl * \return The processed entry */ Entry AssumeNoZeroDivisor(Entry divisor) { - ICHECK(!divisor.is_const(0)) << "Find divide by zero"; + TVM_FFI_ICHECK(!divisor.is_const(0)) << "Find divide by zero"; // NOTE: here we make the assumption that // divide by zero won't happen in a valid program // this is important for us to get a lot of symbolic shape bound right @@ -234,7 +234,7 @@ class ConstIntBoundAnalyzer::Impl // when mod or divide of n occur, the intention is actually n > 0 if (divisor.min_value == 0) { divisor.min_value = 1; - ICHECK_GE(divisor.max_value, 1); + TVM_FFI_ICHECK_GE(divisor.max_value, 1); } return divisor; } @@ -316,7 +316,7 @@ class ConstIntBoundAnalyzer::Impl std::min(std::max(a.max_value, (int64_t)0), b_max_cap)); } } else { - ICHECK(!b.is_const(0)) << "mod by zero"; + TVM_FFI_ICHECK(!b.is_const(0)) << "mod by zero"; // mod by negative value is rare, // and we just use the simpliest rule. return Everything(op->dtype); @@ -387,7 +387,7 @@ class ConstIntBoundAnalyzer::Impl return MakeBound(0, b_max_cap); } } else { - ICHECK(!b.is_const(0)) << "floormod by zero"; + TVM_FFI_ICHECK(!b.is_const(0)) << "floormod by zero"; int64_t b_min_cap = InfAwareAdd(b.min_value, 1); int64_t b_max_cap = InfAwareAdd(b.max_value, -1); return Intersect(MakeBound(std::min(static_cast(0), b_min_cap), @@ -590,11 +590,11 @@ class ConstIntBoundAnalyzer::Impl */ static int64_t InfAwareAdd(int64_t x, int64_t y) { if (x == kPosInf) { - ICHECK(y != kNegInf); + TVM_FFI_ICHECK(y != kNegInf); return kPosInf; } if (x == kNegInf) { - ICHECK(y != kPosInf); + TVM_FFI_ICHECK(y != kPosInf); return kNegInf; } if (y == kPosInf || y == kNegInf) return y; @@ -622,7 +622,7 @@ class ConstIntBoundAnalyzer::Impl * \return the result. */ static int64_t InfAwareDiv(int64_t x, int64_t y) { - ICHECK_NE(y, 0); + TVM_FFI_ICHECK_NE(y, 0); if (x == kPosInf || x == kNegInf) { if (y > 0) return x; return -x; @@ -636,7 +636,7 @@ class ConstIntBoundAnalyzer::Impl * \return the result. */ static int64_t InfAwareFloorDiv(int64_t x, int64_t y) { - ICHECK_NE(y, 0); + TVM_FFI_ICHECK_NE(y, 0); if (x == kPosInf || x == kNegInf) { if (y > 0) return x; return -x; diff --git a/src/arith/detect_common_subexpr.cc b/src/arith/detect_common_subexpr.cc index ce7f1ec4c586..0f116b90d3e4 100644 --- a/src/arith/detect_common_subexpr.cc +++ b/src/arith/detect_common_subexpr.cc @@ -35,8 +35,8 @@ using namespace tir; ffi::Map DetectCommonSubExpr(const PrimExpr& e, int thresh) { // Check the threshold in the range of size_t - CHECK_GE(thresh, std::numeric_limits::min()); - CHECK_LE(thresh, std::numeric_limits::max()); + TVM_FFI_ICHECK_GE(thresh, std::numeric_limits::min()); + TVM_FFI_ICHECK_LE(thresh, std::numeric_limits::max()); size_t repeat_thr = static_cast(thresh); auto IsEligibleComputation = [](const PrimExpr& expr) { return (SideEffect(expr) <= CallEffectKind::kPure && CalculateExprComplexity(expr) > 1 && diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 3fc6d34b7071..dbfa334107ec 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -85,7 +85,8 @@ class BufferTouchedDomain final : public IRVisitorWithAnalyzer { } else if (consider_stores) { bounds = std::get(kv->second).set; } else { - CHECK(false) << "Must consider at least on of either loads and stores, but both are false"; + TVM_FFI_ICHECK(false) + << "Must consider at least on of either loads and stores, but both are false"; } for (size_t i = 0; i < bounds.size(); ++i) { ret.push_back(arith::Union(bounds[i]).CoverRange(none)); diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 81c845906c5e..3f4048bfd191 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -51,9 +51,9 @@ ffi::Array AsConditions(const ffi::Array& variables, ffi::Array res; // use variables to keep the order of iteration // so as to get rid of any non-determinism. - ICHECK_EQ(variables.size(), bounds.size()); + TVM_FFI_ICHECK_EQ(variables.size(), bounds.size()); for (const auto v : variables) { - ICHECK(bounds.count(v)); + TVM_FFI_ICHECK(bounds.count(v)); const auto& bnds = bounds[v]; PrimExpr lhs = bnds->coef * v; for (const PrimExpr& rhs : bnds->equal) { @@ -74,7 +74,7 @@ ffi::Array AsConditions(const ffi::Array& variables, IntGroupBounds::IntGroupBounds(PrimExpr coef, ffi::Array lower, ffi::Array equal, ffi::Array upper) { - ICHECK(coef.dtype().is_int() || coef.dtype().is_uint()) + TVM_FFI_ICHECK(coef.dtype().is_int() || coef.dtype().is_uint()) << "Coefficient in IntGroupBounds must be integers"; ObjectPtr node = ffi::make_object(); node->coef = std::move(coef); @@ -195,7 +195,7 @@ Range IntGroupBounds::FindBestRange(const ffi::Map& vranges_addl) co } if (!best_lower.defined()) { - ICHECK(!best_diff_over.defined()); + TVM_FFI_ICHECK(!best_diff_over.defined()); return Range(); } return Range::FromMinExtent(best_lower, analyzer.Simplify(best_diff_over + 1)); @@ -209,7 +209,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { ffi::Array upper) { return IntGroupBounds(coef, lower, equal, upper); }) .def("arith.IntGroupBounds_from_range", IntGroupBounds::FromRange) .def_packed("arith.IntGroupBounds_FindBestRange", [](ffi::PackedArgs args, ffi::Any* ret) { - ICHECK(args.size() == 1 || args.size() == 2); + TVM_FFI_ICHECK(args.size() == 1 || args.size() == 2); auto bounds = args[0].cast(); if (args.size() == 1) { *ret = bounds.FindBestRange(); @@ -235,9 +235,9 @@ IntConstraints::IntConstraints(ffi::Array variables, ffi::Map r if (!ranges.defined()) { ranges = ffi::Map(); } - ICHECK(relations.defined()); + TVM_FFI_ICHECK(relations.defined()); for (const auto& var : variables) { - ICHECK(var.dtype().is_int() || var.dtype().is_uint()) + TVM_FFI_ICHECK(var.dtype().is_int() || var.dtype().is_uint()) << "Variables in IntConstraints must be integers"; } node->variables = std::move(variables); @@ -275,7 +275,7 @@ IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstrai IntConstraintsTransform IntConstraintsTransform::operator+( const IntConstraintsTransform& other) const { - ICHECK(other->src.same_as(operator->()->dst)); + TVM_FFI_ICHECK(other->src.same_as(operator->()->dst)); ffi::Map dst_to_src; ffi::Map src_to_dst; diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 1433ceb70fc0..25c825cbbc7b 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -213,7 +213,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval if (b->IsEmpty()) return b; if (b->IsSinglePoint()) { if (is_zero(b->min_value)) { - LOG(FATAL) << "Divide by zero in CombineInterval Div"; + TVM_FFI_THROW(InternalError) << "Divide by zero in CombineInterval Div"; } if (is_one(b->min_value)) return a; // no relaxation is needed in here due to set is inclusive @@ -249,7 +249,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval if (b->IsSinglePoint()) { const PrimExpr& divisor = b->min_value; if (is_zero(divisor)) { - LOG(FATAL) << "Modular by zero in CombineInterval Mod"; + TVM_FFI_THROW(InternalError) << "Modular by zero in CombineInterval Mod"; } // We need to add more bound constraints throughout the code. // The logic below assumes a is non-negative, which usually @@ -276,7 +276,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int if (b->IsEmpty()) return b; if (b->IsSinglePoint()) { if (is_zero(b->min_value)) { - LOG(FATAL) << "Divide by zero in CombineInterval Div"; + TVM_FFI_THROW(InternalError) << "Divide by zero in CombineInterval Div"; } if (is_one(b->min_value)) return a; // no relaxation is needed in here due to set is inclusive @@ -312,7 +312,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int if (b->IsSinglePoint()) { const PrimExpr& divisor = b->min_value; if (is_zero(divisor)) { - LOG(FATAL) << "Modular by zero in CombineInterval Mod"; + TVM_FFI_THROW(InternalError) << "Modular by zero in CombineInterval Mod"; } if (analyzer->CanProveGreaterEqual(divisor, 0)) { if (divisor.as()) { @@ -496,7 +496,7 @@ class IntervalSetEvaluator : public ExprFunctor { IntervalSet VisitExpr_(const OrNode* op) final { return VisitBinaryExpr_(op); } IntervalSet VisitExpr_(const RampNode* op) final { - ICHECK(eval_vec_); + TVM_FFI_ICHECK(eval_vec_); IntervalSet base = Eval(op->base); PVar stride; if (stride.Match(op->stride)) { @@ -532,7 +532,7 @@ class IntervalSetEvaluator : public ExprFunctor { } IntervalSet VisitExpr_(const BroadcastNode* op) final { - ICHECK(eval_vec_); + TVM_FFI_ICHECK(eval_vec_); return VisitExpr(op->value); } @@ -674,12 +674,12 @@ void IntSetAnalyzer::Impl::Update(const Var& var, const IntSet& info, bool can_o if (it != dom_map_.end()) { const IntSet& old_info = (*it).second; - ICHECK(ExprDeepEqual()(old_info.min(), info.min())) + TVM_FFI_ICHECK(ExprDeepEqual()(old_info.min(), info.min())) << "Trying to update var \'" << var << "\'" << " with a different minimum value: " << "original=" << old_info.min() << ", new=" << info.min(); - ICHECK(ExprDeepEqual()(old_info.max(), info.max())) + TVM_FFI_ICHECK(ExprDeepEqual()(old_info.max(), info.max())) << "Trying to update var \'" << var << "\'" << " with a different maximum value: " << "original=" << old_info.max() << ", new=" << info.max(); @@ -739,7 +739,7 @@ std::function IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& cons dom_constraints_.insert(dom_constraints_.end(), bounds.begin(), bounds.end()); size_t new_size = dom_constraints_.size(); auto frecover = [old_size, new_size, this]() { - ICHECK_EQ(dom_constraints_.size(), new_size); + TVM_FFI_ICHECK_EQ(dom_constraints_.size(), new_size); dom_constraints_.resize(old_size); }; return frecover; @@ -751,7 +751,7 @@ Range IntSet::CoverRange(Range max_range) const { IntSet temp; Analyzer analyzer; const IntervalSetNode* s_int = (*this).as(); - ICHECK(s_int != nullptr); + TVM_FFI_ICHECK(s_int != nullptr); if (s_int->HasUpperBound() && s_int->HasLowerBound()) { return Range::FromMinExtent(analyzer.Simplify(s_int->min_value), analyzer.Simplify(s_int->max_value + 1 - s_int->min_value)); @@ -761,13 +761,13 @@ Range IntSet::CoverRange(Range max_range) const { PrimExpr IntSet::min() const { const IntervalSetNode* s_int = (*this).as(); - ICHECK(s_int); + TVM_FFI_ICHECK(s_int); return s_int->min_value; } PrimExpr IntSet::max() const { const IntervalSetNode* s_int = (*this).as(); - ICHECK(s_int); + TVM_FFI_ICHECK(s_int); return s_int->max_value; } @@ -850,7 +850,7 @@ SignType IntSet::GetSignType() const { } PrimExpr IntSet::PointValue() const { const IntervalSetNode* s_int = (*this).as(); - ICHECK(s_int && s_int->IsSinglePoint()); + TVM_FFI_ICHECK(s_int && s_int->IsSinglePoint()); return s_int->min_value; } @@ -1116,7 +1116,7 @@ static ffi::Optional EvalIterSum(const IterSumExpr& iter_min, const Prim if (iter_min->args.empty()) { return IntSet::FromMinExtent(iter_min->base, extent); } - ICHECK_EQ(iter_min->args.size(), 1) << "The `EvalIterSum` expects fused iter sum expr"; + TVM_FFI_ICHECK_EQ(iter_min->args.size(), 1) << "The `EvalIterSum` expects fused iter sum expr"; const IterSplitExpr& split = iter_min->args[0]; if (analyzer->CanProve(split->extent == 0)) { return IntSet::Nothing(); @@ -1165,7 +1165,7 @@ ffi::Optional> EstimateRegionStrictBound(const ffi::Array result; result.reserve(ndim); for (int i = 0; i < ndim; ++i) { @@ -1205,14 +1205,14 @@ ffi::Array EstimateRegionUpperBound(const ffi::Array& region, /*indices=*/{range->min}, /*input_iters=*/var_dom, /*predicate=*/predicate, /*check_level=*/IterMapLevel::Surjective, analyzer); if (!res->indices.empty()) { - ICHECK_EQ(res->indices.size(), 1U); + TVM_FFI_ICHECK_EQ(res->indices.size(), 1U); IterSumExpr sum_expr = res->indices[0]; // dynamic extent is not supported yet. PrimExpr extent = range->extent; if (!is_const_number(extent)) { IntSet relaxed = EvalSet(extent, AsIntSet(var_dom)); - ICHECK(relaxed.HasUpperBound()); + TVM_FFI_ICHECK(relaxed.HasUpperBound()); extent = relaxed.max(); } diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 8fb6dba8764a..e57e4c0d5aa2 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -134,7 +134,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) { IterVar iv = Downcast(op->node); - ICHECK_NE(iv->thread_tag.length(), 0U); + TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); Range dom = Range::FromMinExtent(make_zero(op->value.dtype()), op->value); analyzer_->Bind(iv->var, dom); iter_vars_.Set(iv->var, dom); diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc index 88eff9fc2c42..fada12e9c425 100644 --- a/src/arith/ir_visitor_with_analyzer.cc +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -67,7 +67,7 @@ void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) { IterVar iv = Downcast(op->node); - ICHECK_NE(iv->thread_tag.length(), 0U); + TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); analyzer_.Bind(iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), op->value)); } StmtExprVisitor::VisitStmt_(op); diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 3de431fb9574..4d08c790724e 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -376,7 +376,7 @@ class IterMapRewriter : public ExprMutator { * It is not an error for IterMapRewriter to receive an expression that * cannot be represented as an IterSumExpr. In these cases, * IterMapRewriter returns the unrepresentable portions of the TIR graph - * without modification. As a result, the usual ICHECK or LOG(FATAL) + * without modification. As a result, the usual ICHECK or TVM_FFI_THROW(InternalError) * macros cannot be used. Instead, ErrorLogger(this) can be used to * report an unrepresentable TIR graph, which may be used in error * messages at the calling scope. @@ -687,18 +687,18 @@ class IterMapRewriter : public ExprMutator { predicate_induced_max = predicate_induced_max.value() - base; } ffi::Optional opt = TryFuseIters(expr, check_level_, false); - ICHECK(!opt.defined() || opt.value()->args.size() == 1); + TVM_FFI_ICHECK(!opt.defined() || opt.value()->args.size() == 1); // scale should be 1 if (opt.defined() && is_one(opt.value()->args[0]->scale)) { const IterSplitExpr split = opt.value()->args[0]; IterSumExpr structured_form = Downcast(split->source->source); // get the flattened form auto it = flattened_map_.find(structured_form); - ICHECK(it != flattened_map_.end()); + TVM_FFI_ICHECK(it != flattened_map_.end()); IterSumExpr flattened_form = it->second; // get the mark and offset of the structured_form auto it_mark = sum_fuse_map_.find(flattened_form); - ICHECK(it_mark != sum_fuse_map_.end()); + TVM_FFI_ICHECK(it_mark != sum_fuse_map_.end()); IterMark mark = it_mark->second.mark; PrimExpr mark_offset = it_mark->second.offset; PrimExpr iter_min = mark_offset; @@ -842,7 +842,7 @@ class IterMapRewriter : public ExprMutator { } else if (auto op = expr.as()) { return IterSumExpr({op.value()}, make_zero(expr->dtype)); } else { - ICHECK(!expr->IsInstance()); + TVM_FFI_ICHECK(!expr->IsInstance()); return IterSumExpr({}, expr); } } @@ -1053,7 +1053,7 @@ class IterMapRewriter : public ExprMutator { // - result->extent = lhs->extent * rhs->extent // Find base index, must have a candidate to make progress int matched_index = FindBaseIter(expr, visited, expr->args[rend]->source, rend); - ICHECK_NE(matched_index, -1); + TVM_FFI_ICHECK_NE(matched_index, -1); visited[matched_index] = true; IterSplitExpr rhs_iter = expr->args[matched_index]; @@ -1064,7 +1064,7 @@ class IterMapRewriter : public ExprMutator { first_possible_unit_extent_pos); if (matched_index == -1) break; IterSplitExpr lhs_iter = expr->args[matched_index]; - ICHECK(rhs_iter->source.same_as(lhs_iter->source)); + TVM_FFI_ICHECK(rhs_iter->source.same_as(lhs_iter->source)); PrimExpr lhs_lower_factor = MulAndNormalize(rhs_iter->lower_factor, rhs_iter->extent); if (!analyzer_->CanProveEqual(lhs_iter->lower_factor, lhs_lower_factor)) break; // all patterns match @@ -1141,7 +1141,7 @@ class IterMapRewriter : public ExprMutator { if (matched_pos == -1) { return std::nullopt; } - ICHECK(matched_scale.defined()); + TVM_FFI_ICHECK(matched_scale.defined()); // look for the longest constrained iter started from expr->args[j] // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) // predicate: j*2 + k < 9 @@ -1179,7 +1179,7 @@ class IterMapRewriter : public ExprMutator { flattened_iters.push_back(expr->args[k]); } auto iter = sum_fuse_map_.find(constraint_to_match.value()); - ICHECK(iter != sum_fuse_map_.end()); + TVM_FFI_ICHECK(iter != sum_fuse_map_.end()); const IterMarkWithOffset& iter_matched = iter->second; grouped_iters.emplace_back(iter_matched.mark, floordiv(matched_scale, base_scale)); expected_extra_base += iter_matched.offset * matched_scale; @@ -1539,7 +1539,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { IterSumExpr NormalizeToIterSum(PrimExpr index, const ffi::Map& input_iters, arith::Analyzer* analyzer) { IterMapResult result; - ICHECK(IterRangeSanityCheck(input_iters)) + TVM_FFI_ICHECK(IterRangeSanityCheck(input_iters)) << "Invalid iterators. Iterators may not be expressions of each other."; // we skip constraint check as the most important thing here is only the pattern @@ -1673,7 +1673,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { return ret; } else { - ICHECK(a->IsInstance()); + TVM_FFI_ICHECK(a->IsInstance()); IterSplitExpr ret = Downcast(std::move(a)); ret.CopyOnWrite()->scale *= b; return ret; @@ -1698,10 +1698,10 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o return IterSumExpr(); } IterSumExpr fused = opt_fused.value(); - ICHECK_EQ(fused->args.size(), 1U); + TVM_FFI_ICHECK_EQ(fused->args.size(), 1U); return fused; } else { - LOG(FATAL) << "Unsupported subclass of IterMarkExpr"; + TVM_FFI_THROW(InternalError) << "Unsupported subclass of IterMarkExpr"; } } @@ -1798,10 +1798,10 @@ std::pair IterMapRewriter::PadDividendToDivisor(IterSpl } // check that padding factor is compatible with current split and divisor - ICHECK(CanProveDivisible(info.padding_factor, split->lower_factor)) + TVM_FFI_ICHECK(CanProveDivisible(info.padding_factor, split->lower_factor)) << "The padding factor " << info.padding_factor << " is not divisible by " << split->lower_factor << " for the split " << split; - ICHECK(CanProveDivisible(info.padding_factor, divisor)) + TVM_FFI_ICHECK(CanProveDivisible(info.padding_factor, divisor)) << "The padding factor " << info.padding_factor << " is not divisible by " << divisor << " for the split " << split; @@ -1980,7 +1980,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { if (!preprocessed.defined()) { return ffi::GetRef(op); } - ICHECK_EQ(preprocessed->args.size(), 1U); + TVM_FFI_ICHECK_EQ(preprocessed->args.size(), 1U); PrimExpr remainder = SplitFloorDivConst(preprocessed->args[0], preprocessed->base, b); if (!remainder.defined()) { return ffi::GetRef(op); @@ -2065,7 +2065,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { return ffi::GetRef(op); } - ICHECK_EQ(preprocessed->args.size(), 1U); + TVM_FFI_ICHECK_EQ(preprocessed->args.size(), 1U); PrimExpr remainder = SplitFloorModConst(preprocessed->args[0], preprocessed->base, b); if (!remainder.defined()) { return ffi::GetRef(op); @@ -2281,7 +2281,7 @@ class SubspaceDivider { } else if (auto op = expr.as()) { return IterSplitExpr(IterMark(op.value(), extent)); } else { - LOG(FATAL) << "Unknown IterMapExpr type"; + TVM_FFI_THROW(InternalError) << "Unknown IterMapExpr type"; } } }; @@ -2543,7 +2543,7 @@ class InverseAffineIterMapTransformer { ffi::Map operator()(const ffi::Array& iter_map, const ffi::Array& outputs) { - ICHECK(iter_map.size() == outputs.size()); + TVM_FFI_ICHECK(iter_map.size() == outputs.size()); std::vector post_dfs_order = ReverseTopologyOrder(iter_map); // initialize back propagation accumulator @@ -2559,7 +2559,7 @@ class InverseAffineIterMapTransformer { if (node->IsInstance()) { Visit_(Downcast(ffi::GetRef(node))); } else { - ICHECK(node->IsInstance()); + TVM_FFI_ICHECK(node->IsInstance()); Visit_(Downcast(ffi::GetRef(node))); } } @@ -2573,7 +2573,7 @@ class InverseAffineIterMapTransformer { // Case 1: Propagate to the input node directly when the sum expression has only one components if (iter_map_expr->args.size() == 1) { const auto& source = iter_map_expr->args[0]; - ICHECK(analyzer_->CanProveEqual(abs(source->scale), 1)); + TVM_FFI_ICHECK(analyzer_->CanProveEqual(abs(source->scale), 1)); backprop_.Set(source, (backprop_.at(source) + input) * source->scale); return; } @@ -2612,7 +2612,7 @@ class InverseAffineIterMapTransformer { } } else { const auto* split_expr = expr.as(); - ICHECK(split_expr); + TVM_FFI_ICHECK(split_expr); if (auto source = split_expr->source->source.as()) { fvisit(source.value()); } @@ -2652,7 +2652,7 @@ class InverseAffineIterMapTransformer { } PrimExpr expected_scale = sum_expr->args.back()->scale; for (size_t i = sum_expr->args.size(); i > 0; i--) { - ICHECK(analyzer_->CanProveEqual(sum_expr->args[i - 1]->scale, expected_scale)); + TVM_FFI_ICHECK(analyzer_->CanProveEqual(sum_expr->args[i - 1]->scale, expected_scale)); expected_scale *= sum_expr->args[i - 1]->extent; } } diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index e69b8ad20e85..1c3233959da0 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -108,7 +108,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctorsecond == info) + TVM_FFI_ICHECK(it->second == info) << "Trying to update var \'" << var << "\'" << " with a different const bound: " << "original=" << ModularSet(it->second.coeff, it->second.base) << ", new=" << info; @@ -184,7 +184,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctor a x @@ -235,7 +235,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctor 0 && (round_down || parent_->CanProveGreaterEqual(lhs, 0)))) { diff --git a/src/arith/narrow_predicate_expression.cc b/src/arith/narrow_predicate_expression.cc index d73364cf45ca..07337ee1e151 100644 --- a/src/arith/narrow_predicate_expression.cc +++ b/src/arith/narrow_predicate_expression.cc @@ -51,7 +51,7 @@ using namespace tir; class ExpressionNarrower : public tir::ExprMutator { public: static PrimExpr Apply(PrimExpr expr, ffi::Map free_parameters) { - ICHECK(expr.dtype().is_bool()) << "Expected boolean expression, but received " << expr; + TVM_FFI_ICHECK(expr.dtype().is_bool()) << "Expected boolean expression, but received " << expr; ExpressionNarrower mutator(free_parameters); return mutator(expr); } @@ -193,7 +193,7 @@ class ExpressionNarrower : public tir::ExprMutator { return Context::Minimize; default: - LOG(FATAL) << "Unhandled Context, all legal values should be handled"; + TVM_FFI_THROW(InternalError) << "Unhandled Context, all legal values should be handled"; } } diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 7c498d7a9c90..626d0b9cbab5 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -49,10 +49,10 @@ * arith::PVar v; * // We can match integer and Var, both of which are * // special case container of Expr - * ICHECK((v * c).Match(tx * 3)); - * ICHECK_EQ(c.Eval()->value, 3); + * TVM_FFI_ICHECK((v * c).Match(tx * 3)); + * TVM_FFI_ICHECK_EQ(c.Eval()->value, 3); * // cannot match c to ty - * ICHECK(!(v * c).Match(tx * ty)); + * TVM_FFI_ICHECK(!(v * c).Match(tx * ty)); * * \endcode * @@ -221,7 +221,7 @@ class PVar : public Pattern> { } T Eval() const { - ICHECK(filled_); + TVM_FFI_ICHECK(filled_); return value_; } diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc index f69761259683..4be7f8442a55 100644 --- a/src/arith/presburger_set.cc +++ b/src/arith/presburger_set.cc @@ -83,7 +83,8 @@ static void Update(const PrimExpr& constraint, PresburgerSetNode* intset) { } disjunct.addEquality(int_coeffs); } else { - LOG(FATAL) << "Unsupported constraint expression: " << entry->GetTypeKey(); + TVM_FFI_THROW(InternalError) + << "Unsupported constraint expression: " << entry->GetTypeKey(); } } intset->unionInPlace(disjunct); @@ -186,7 +187,7 @@ PrimExpr PresburgerSetNode::GenerateConstraint() const { } PresburgerSet Union(const ffi::Array& sets) { - CHECK_GT(sets.size(), 0); + TVM_FFI_ICHECK_GT(sets.size(), 0); if (sets.size() == 1) return sets[0]; auto relations = sets[0]->disjuncts; for (size_t i = 1; i < sets.size(); ++i) { @@ -198,13 +199,13 @@ PresburgerSet Union(const ffi::Array& sets) { } PresburgerSet Intersect(const ffi::Array& sets) { - CHECK_GT(sets.size(), 0); + TVM_FFI_ICHECK_GT(sets.size(), 0); if (sets.size() == 1) return sets[0]; auto relations = sets[0]->disjuncts; const auto& space = sets[0]->space; for (size_t i = 1; i < sets.size(); ++i) { - ICHECK(space.isCompatible(sets[i]->space)) << "Spaces should match"; + TVM_FFI_ICHECK(space.isCompatible(sets[i]->space)) << "Spaces should match"; for (const IntegerRelation& relA : sets[i]->disjuncts) { for (const IntegerRelation& relB : relations) { IntegerRelation intersection = relA.intersect(relB); @@ -262,7 +263,7 @@ IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) { TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto set = node.as(); - ICHECK(ret) << "Unknown type:" << node->GetTypeKey(); + TVM_FFI_ICHECK(ret) << "Unknown type:" << node->GetTypeKey(); p->stream << "{"; p->stream << set->GetVars() << ": "; p->stream << node.as()->GenerateConstraint(); diff --git a/src/arith/presburger_set.h b/src/arith/presburger_set.h index 2404f36428f6..cf624f757e5f 100644 --- a/src/arith/presburger_set.h +++ b/src/arith/presburger_set.h @@ -166,7 +166,9 @@ class PresburgerSet : public IntSet { * \param constraint The constraint to construct the set. * \return The created set. */ - TVM_DLL PresburgerSet(const PrimExpr& constraint) { LOG(FATAL) << "MLIR is not enabled!"; } + TVM_DLL PresburgerSet(const PrimExpr& constraint) { + TVM_FFI_THROW(InternalError) << "MLIR is not enabled!"; + } }; #endif // TVM_MLIR_VERSION /*! diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 65b6e408e2cb..7ae2f09e3990 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -376,9 +376,10 @@ void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool if (!can_override) { auto it = var_map_.find(var); if (it != var_map_.end()) { - ICHECK(ExprDeepEqual()(it->second, info)) << "Trying to update var \'" << var << "\'" - << " with a different value: " - << "original=" << it->second << ", new=" << info; + TVM_FFI_ICHECK(ExprDeepEqual()(it->second, info)) + << "Trying to update var \'" << var << "\'" + << " with a different value: " + << "original=" << it->second << ", new=" << info; } } var_map_[var] = info; @@ -523,7 +524,7 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c stats_.constraints_entered++; size_t new_literal_size = literal_constraints_.size(); auto frecover = [old_literal_size, new_literal_size, this]() { - ICHECK_EQ(literal_constraints_.size(), new_literal_size); + TVM_FFI_ICHECK_EQ(literal_constraints_.size(), new_literal_size); literal_constraints_.resize(old_literal_size); }; return frecover; @@ -782,7 +783,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - ICHECK(c2val != 0) << "division by zero"; + TVM_FFI_ICHECK(c2val != 0) << "division by zero"; if (c1val % c2val == 0) { return ramp(div(b1, c2), div(c1, c2), lanes).Eval(); } @@ -937,7 +938,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { if (truncmod(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - ICHECK(c2val != 0) << "division by zero"; + TVM_FFI_ICHECK(c2val != 0) << "division by zero"; if (c1val % c2val == 0) { return broadcast(truncmod(b1, c2), lanes).Eval(); } @@ -1025,7 +1026,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { if (floordiv(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - ICHECK(c2val != 0) << "division by zero"; + TVM_FFI_ICHECK(c2val != 0) << "division by zero"; if (c1val % c2val == 0) { return ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval(); } @@ -1172,7 +1173,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (floormod(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - ICHECK(c2val != 0) << "division by zero"; + TVM_FFI_ICHECK(c2val != 0) << "division by zero"; if (c1val % c2val == 0) { return broadcast(floormod(b1, c2), lanes).Eval(); } @@ -1756,7 +1757,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LENode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - ICHECK(op); + TVM_FFI_ICHECK(op); if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); @@ -2309,7 +2310,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { return IntImm(op->dtype, bits - i - 1); } } - LOG(FATAL) << "Should not reach here"; + TVM_FFI_THROW(InternalError) << "Should not reach here"; } } diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index e541970a2717..976e490ddfd5 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -142,7 +142,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { void RecordRewrite() { stats_.rewrites_performed++; - ICHECK(maximum_rewrite_steps_ <= 0 || stats_.rewrites_performed <= maximum_rewrite_steps_) + TVM_FFI_ICHECK(maximum_rewrite_steps_ <= 0 || + stats_.rewrites_performed <= maximum_rewrite_steps_) << "RewriteSimplifier exceeded maximum number of rewrites allowed (" << maximum_rewrite_steps_ << ")"; } diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 0afdbb4a58e5..88c55576206d 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -45,8 +45,8 @@ void SmithNormalFormDiag(std::vector>* S, std::vectorempty() || V->empty()) return; size_t m = S->size(); size_t n = (*S)[0].size(); // n is # of variables - ICHECK_EQ(V->size(), n); - ICHECK_EQ((*V)[0].size(), n); + TVM_FFI_ICHECK_EQ(V->size(), n); + TVM_FFI_ICHECK_EQ((*V)[0].size(), n); for (size_t index = 0; index < std::min(m, n); ++index) { // Here A is partially diagonalized, that is A[i, j] is zero for all i, j @@ -472,7 +472,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { opt_relations.value_or({})); *ret = SolveLinearEquations(problem); } else { - LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size(); + TVM_FFI_THROW(InternalError) + << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size(); } }); } diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index a46f9e520176..3b8e96773ba5 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -220,7 +220,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t ffi::Map res_bounds; for (const Var& v : system_to_solve->variables) { - ICHECK(!res_bounds.count(v)) + TVM_FFI_ICHECK(!res_bounds.count(v)) << "Variable " << v << " appears more than one time in the `variables` which might be a bug"; @@ -388,7 +388,7 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { analyzer.Bind(vranges); const Var& var = *it; - ICHECK(solved_bounds.count(var)); + TVM_FFI_ICHECK(solved_bounds.count(var)); auto bnd = solved_bounds[var]; if (is_one(bnd->coef) && !bnd->equal.empty()) { // There is an equation of the form `v == expr`, so this variable can be completely removed. @@ -539,25 +539,25 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_packed( - "arith.SolveInequalitiesAsCondition", - [](ffi::PackedArgs args, ffi::Any* ret) { - IntConstraints problem; - PartialSolvedInequalities ret_ineq; - if (args.size() == 1) { - problem = args[0].cast(); - ret_ineq = SolveLinearInequalities(problem); - } else if (args.size() == 3) { - problem = IntConstraints(args[0].cast>(), - args[1].cast>(), - args[2].cast>()); - ret_ineq = SolveLinearInequalities(problem); - } else { - LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets " - << args.size(); - } - *ret = AsConditions(problem->variables, ret_ineq.first, ret_ineq.second); - }) + .def_packed("arith.SolveInequalitiesAsCondition", + [](ffi::PackedArgs args, ffi::Any* ret) { + IntConstraints problem; + PartialSolvedInequalities ret_ineq; + if (args.size() == 1) { + problem = args[0].cast(); + ret_ineq = SolveLinearInequalities(problem); + } else if (args.size() == 3) { + problem = IntConstraints(args[0].cast>(), + args[1].cast>(), + args[2].cast>()); + ret_ineq = SolveLinearInequalities(problem); + } else { + TVM_FFI_THROW(InternalError) + << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets " + << args.size(); + } + *ret = AsConditions(problem->variables, ret_ineq.first, ret_ineq.second); + }) .def_packed("arith.SolveInequalitiesToRange", [](ffi::PackedArgs args, ffi::Any* ret) { if (args.size() == 1) { @@ -568,8 +568,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { args[2].cast>()); *ret = SolveInequalitiesToRange(problem); } else { - LOG(FATAL) << "arith.SolveInequalitiesToRange expects 1 or 3 arguments, gets " - << args.size(); + TVM_FFI_THROW(InternalError) + << "arith.SolveInequalitiesToRange expects 1 or 3 arguments, gets " + << args.size(); } }) .def_packed("arith.SolveInequalitiesDeskewRange", [](ffi::PackedArgs args, ffi::Any* ret) { @@ -581,8 +582,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { args[2].cast>()); *ret = SolveInequalitiesDeskewRange(problem); } else { - LOG(FATAL) << "arith.SolveInequalitiesDeskewRange expects 1 or 3 arguments, gets " - << args.size(); + TVM_FFI_THROW(InternalError) + << "arith.SolveInequalitiesDeskewRange expects 1 or 3 arguments, gets " + << args.size(); } }); } diff --git a/src/arith/transitive_comparison_analyzer.cc b/src/arith/transitive_comparison_analyzer.cc index b4cd7b260ebb..3794ff150bb9 100644 --- a/src/arith/transitive_comparison_analyzer.cc +++ b/src/arith/transitive_comparison_analyzer.cc @@ -320,7 +320,7 @@ CompareResult Reverse(CompareResult res) { case CompareResult::kUnknown: return CompareResult::kUnknown; default: - LOG(FATAL) << "Invalid CompareResult: " << static_cast(res); + TVM_FFI_THROW(InternalError) << "Invalid CompareResult: " << static_cast(res); } } @@ -478,10 +478,10 @@ TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const { bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies( const TransitiveComparisonAnalyzer::Impl::Comparison& other) const { - ICHECK(lhs_ == other.lhs_); - ICHECK(rhs_ == other.rhs_); - ICHECK(IsNormalized()); - ICHECK(other.IsNormalized()); + TVM_FFI_ICHECK(lhs_ == other.lhs_); + TVM_FFI_ICHECK(rhs_ == other.rhs_); + TVM_FFI_ICHECK(IsNormalized()); + TVM_FFI_ICHECK(other.IsNormalized()); if (result_ == other.result_ && offset_ == other.offset_) { // if c1 == c2, x != y + c1 => x != y + c2 @@ -563,8 +563,8 @@ void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& bool differs_from_previous = !expr_equal(range->min, (*it).second->min) || !expr_equal(range->extent, (*it).second->extent); if (differs_from_previous) { - ICHECK(allow_override) << "Binding of variable " << var << " as " << range - << " conflicts with previous binding as " << (*it).second; + TVM_FFI_ICHECK(allow_override) << "Binding of variable " << var << " as " << range + << " conflicts with previous binding as " << (*it).second; if (auto key = ExprToPreviousKey(var)) { knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(), [&](const auto& known) { return known.lhs_ == key.value(); }), @@ -594,7 +594,7 @@ std::function TransitiveComparisonAnalyzer::Impl::EnterConstraint(const size_t new_literal_size = scoped_knowns_.size(); auto frecover = [old_literal_size, new_literal_size, this]() { - ICHECK_EQ(scoped_knowns_.size(), new_literal_size); + TVM_FFI_ICHECK_EQ(scoped_knowns_.size(), new_literal_size); scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end()); }; return frecover; @@ -667,7 +667,7 @@ TransitiveComparisonAnalyzer::Impl::CollectIndirectComparisons(Key lhs_key, Key auto output = DFSFromLHS(lhs_key, rhs_key); for (Comparison cmp : DFSFromLHS(rhs_key, lhs_key)) { auto opt_normalized = cmp.WithLHS(lhs_key); - ICHECK(opt_normalized.has_value()); + TVM_FFI_ICHECK(opt_normalized.has_value()); output.push_back(opt_normalized.value()); } return output; @@ -732,12 +732,12 @@ TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key, Key rhs_key) const { to_visit.erase(to_visit.begin()); std::vector& prev_knowns_using_middle = compared_to_lhs.at(middle_key); - ICHECK(compared_to_lhs.count(middle_key)); + TVM_FFI_ICHECK(compared_to_lhs.count(middle_key)); std::vector new_knowns_using_lhs; auto attempt_transitive = [&](Comparison cmp) { - ICHECK(cmp.IsNormalized()); + TVM_FFI_ICHECK(cmp.IsNormalized()); Key right_key = cmp.rhs_; @@ -862,10 +862,11 @@ CompareResult TransitiveComparisonAnalyzer::Impl::MergeComparisons( case CompareResult::kGT: case CompareResult::kLT: - LOG(FATAL) << "Internal error, normalized comparisons should only include <= and >="; + TVM_FFI_THROW(InternalError) + << "Internal error, normalized comparisons should only include <= and >="; default: - LOG(FATAL) << "Invalid CompareResult: " << static_cast(cmp.result_); + TVM_FFI_THROW(InternalError) << "Invalid CompareResult: " << static_cast(cmp.result_); } } diff --git a/src/arith/unwrap_vector_expr.cc b/src/arith/unwrap_vector_expr.cc index c074eb5c935a..a73cf89f3671 100644 --- a/src/arith/unwrap_vector_expr.cc +++ b/src/arith/unwrap_vector_expr.cc @@ -62,7 +62,7 @@ class Scalarizer : public ExprMutator { } auto it = let_var_remap_.find(op->var.get()); - ICHECK(it == let_var_remap_.end()) << "Duplicate binding of variable " << op->var; + TVM_FFI_ICHECK(it == let_var_remap_.end()) << "Duplicate binding of variable " << op->var; Var new_var(op->var->name_hint + "_scalar", op->var.dtype().element_of()); let_var_remap_[op->var.get()] = new_var; diff --git a/src/contrib/msc/core/codegen/base_codegen.h b/src/contrib/msc/core/codegen/base_codegen.h index 80747ceed0f2..bb5b5be058c6 100644 --- a/src/contrib/msc/core/codegen/base_codegen.h +++ b/src/contrib/msc/core/codegen/base_codegen.h @@ -188,24 +188,24 @@ class BaseCodeGen { return 1; } if (node->scope.size() == scopes_.top().size()) { - ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top())) + TVM_FFI_ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top())) << "Scope mismatch, node " << node->scope << " compare to current " << scopes_.top(); return 0; } else if (node->scope.size() == scopes_.top().size() + 1) { - ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top(), scopes_.top().size())) + TVM_FFI_ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top(), scopes_.top().size())) << "Scope increase mismatch, node " << node->scope << " compare to current " << scopes_.top(); scopes_.push(node->scope); return 1; } else if (node->scope.size() == scopes_.top().size() - 1) { - ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top(), node->scope.size())) + TVM_FFI_ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top(), node->scope.size())) << "Scope decrease mismatch, node " << node->scope << " compare to current " << scopes_.top(); scopes_.pop(); return -1; } else { - LOG(FATAL) << "Unexpected node scope " << node->scope << " with current scope " - << scopes_.top(); + TVM_FFI_THROW(InternalError) + << "Unexpected node scope " << node->scope << " with current scope " << scopes_.top(); } } diff --git a/src/contrib/msc/core/codegen/code_stack.cc b/src/contrib/msc/core/codegen/code_stack.cc index e1b34f7d28b7..e31726f647af 100644 --- a/src/contrib/msc/core/codegen/code_stack.cc +++ b/src/contrib/msc/core/codegen/code_stack.cc @@ -28,7 +28,7 @@ namespace contrib { namespace msc { const ffi::Array BaseStack::GetDocs() const { - ICHECK(blocks_.size() == 1) << "Has incomplete blocks, please check"; + TVM_FFI_ICHECK(blocks_.size() == 1) << "Has incomplete blocks, please check"; return TopBlock(); } @@ -39,7 +39,7 @@ void BaseStack::Line(const ffi::String& line) { Line(IdDoc(line)); } void BaseStack::Comment(const ffi::String& comment, bool attach) { if (attach) { const auto& doc = TopDoc(); - ICHECK(doc->IsInstance()) << "Only stmt doc support attach comments"; + TVM_FFI_ICHECK(doc->IsInstance()) << "Only stmt doc support attach comments"; const auto& stmt = Downcast(doc); stmt->comment = comment; } else { @@ -85,7 +85,7 @@ void BaseStack::FuncDecorator(const ffi::String& decorator) { } void BaseStack::FuncStart() { - ICHECK(TopDoc()->IsInstance()) << "FunctionDoc is not saved"; + TVM_FFI_ICHECK(TopDoc()->IsInstance()) << "FunctionDoc is not saved"; BlockStart(); } @@ -108,7 +108,7 @@ void BaseStack::ClassDecorator(const ffi::String& decorator) { } void BaseStack::ClassStart() { - ICHECK(TopDoc()->IsInstance()) << "ClassDoc is not saved"; + TVM_FFI_ICHECK(TopDoc()->IsInstance()) << "ClassDoc is not saved"; BlockStart(); } @@ -144,7 +144,7 @@ void BaseStack::ConstructorArg(const ffi::String& arg, const ffi::String& annota } void BaseStack::ConstructorStart() { - ICHECK(TopDoc()->IsInstance()) << "ConstructorDoc is not saved"; + TVM_FFI_ICHECK(TopDoc()->IsInstance()) << "ConstructorDoc is not saved"; BlockStart(); } @@ -176,7 +176,7 @@ void BaseStack::LambdaRef(const ffi::String& ref) { } void BaseStack::LambdaStart() { - ICHECK(TopDoc()->IsInstance()) << "LambdaDoc is not saved"; + TVM_FFI_ICHECK(TopDoc()->IsInstance()) << "LambdaDoc is not saved"; BlockStart(); } @@ -240,11 +240,11 @@ void BaseStack::MethodCall(const ffi::String& callee, bool new_line) { const auto& v_callee = callee + (new_line ? DocSymbol::NextLine() : ""); FuncCall(v_callee, std::nullopt, Downcast(host)); } else if (const auto* a_node = host.as()) { - ICHECK(a_node->rhs.defined()) << "Can not find rhs for inplace host"; + TVM_FFI_ICHECK(a_node->rhs.defined()) << "Can not find rhs for inplace host"; FuncCall(callee, DeclareDoc(a_node->annotation, a_node->lhs, ffi::Array(), true), a_node->rhs); } else { - LOG(FATAL) << "Unexpected host type for inplace " << host->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unexpected host type for inplace " << host->GetTypeKey(); } } @@ -265,12 +265,12 @@ void BaseStack::InplaceEnd() { CallArgBase(Downcast(last)); } else if (const auto* assign = last.as()) { const auto& call = Downcast(assign->rhs); - ICHECK(assign->lhs->IsInstance()) + TVM_FFI_ICHECK(assign->lhs->IsInstance()) << "assign lhs should be IdDoc, get " << assign->lhs->GetTypeKey(); const auto& key = Downcast(assign->lhs)->name; CallArgBase(call, key); } else { - LOG(FATAL) << "Unexpected last type for call arg " << last->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unexpected last type for call arg " << last->GetTypeKey(); } } @@ -279,7 +279,7 @@ void BaseStack::PopNest(const ffi::String& key) { if (last->IsInstance()) { CallArgBase(Downcast(last), key); } else { - LOG(FATAL) << "Unexpected nest type " << last->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unexpected nest type " << last->GetTypeKey(); } } @@ -299,11 +299,11 @@ void BaseStack::CallArgBase(const ExprDoc& value, const ffi::String& key) { kwargs_keys = call->kwargs_keys; kwargs_values = call->kwargs_values; } else { - LOG(FATAL) << "Unexpected last type for call arg " << last->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unexpected last type for call arg " << last->GetTypeKey(); } // push args or kwargs if (key.size() == 0) { - ICHECK(kwargs_keys.size() == 0) << "kwargs followed by args " << value; + TVM_FFI_ICHECK(kwargs_keys.size() == 0) << "kwargs followed by args " << value; args.push_back(value); } else { kwargs_keys.push_back(key); @@ -317,7 +317,7 @@ void BaseStack::CallArgBase(const ExprDoc& value, const ffi::String& key) { const auto& new_call = CallDoc(call->callee, args, kwargs_keys, kwargs_values); PushDoc(AssignDoc(assign->lhs, new_call, assign->annotation)); } else { - LOG(FATAL) << "Unexpected last type for call arg " << last->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unexpected last type for call arg " << last->GetTypeKey(); } } @@ -433,7 +433,7 @@ void BaseStack::ScopeEnd() { bool BaseStack::HasBlock() const { return blocks_.size() > 0; } const ffi::Array BaseStack::TopBlock() const { - ICHECK(HasBlock()) << "No block found"; + TVM_FFI_ICHECK(HasBlock()) << "No block found"; return blocks_.top(); } @@ -451,7 +451,7 @@ bool BaseStack::HasDoc() { } const Doc BaseStack::TopDoc() { - ICHECK(HasDoc()) << "No doc or block found"; + TVM_FFI_ICHECK(HasDoc()) << "No doc or block found"; return TopBlock().back(); } @@ -463,14 +463,14 @@ const Doc BaseStack::PopDoc() { template const TDoc BaseStack::PopCheckedDoc() { - ICHECK(HasDoc() && TopDoc()->IsInstance()) + TVM_FFI_ICHECK(HasDoc() && TopDoc()->IsInstance()) << "Last doc(" << TopDoc()->GetTypeKey() << ") is not expected type " << TDocNode::TypeIndex2Key(TDocNode::RuntimeTypeIndex()); return Downcast(PopDoc()); } void BaseStack::PushDoc(const Doc& doc) { - ICHECK(HasBlock()) << "No block found"; + TVM_FFI_ICHECK(HasBlock()) << "No block found"; blocks_.top().push_back(doc); } diff --git a/src/contrib/msc/core/codegen/codegen_json.h b/src/contrib/msc/core/codegen/codegen_json.h index 65fdfd0a0352..d64717449fda 100644 --- a/src/contrib/msc/core/codegen/codegen_json.h +++ b/src/contrib/msc/core/codegen/codegen_json.h @@ -75,7 +75,7 @@ class MSCJSONSerializer : public JSONSerializer { namespace json = ::tvm::ffi::json; MSCCompileConfig config; config.Load(json::Parse(options).cast()); - ICHECK(config.graph_json.size() > 0) << "graph_json is needed to init MSCGraph"; + TVM_FFI_ICHECK(config.graph_json.size() > 0) << "graph_json is needed to init MSCGraph"; graph_ = MSCGraph(config.graph_json); for (const auto& pair : config.options) { options_.Set(pair.first, pair.second); @@ -86,7 +86,7 @@ class MSCJSONSerializer : public JSONSerializer { std::vector VisitExpr_(const CallNode* call_node) final; const ffi::String GetOption(const ffi::String& key) { - ICHECK(options_.count(key)) << "Can not find option " << key; + TVM_FFI_ICHECK(options_.count(key)) << "Can not find option " << key; return options_[key]; } diff --git a/src/contrib/msc/core/codegen/cpp_codegen.h b/src/contrib/msc/core/codegen/cpp_codegen.h index 5043da3744a3..ee9fb490c8b3 100644 --- a/src/contrib/msc/core/codegen/cpp_codegen.h +++ b/src/contrib/msc/core/codegen/cpp_codegen.h @@ -163,7 +163,7 @@ class CppCodeGen : public BaseCodeGen { break; } } - ICHECK(tensor_ctx.count("tensor")) + TVM_FFI_ICHECK(tensor_ctx.count("tensor")) << "Can not find weight " << tensor << " from " << producer; } else { const auto& pair = this->graph()->FindProducerAndIdx(tensor); diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc index ada1922f22e6..c8ad5309fcff 100644 --- a/src/contrib/msc/core/ir/graph.cc +++ b/src/contrib/msc/core/ir/graph.cc @@ -413,7 +413,7 @@ void MSCJointNode::FromJson(const JsonMSCJoint& j_joint, scope.push_back(s); } for (const auto& p_name : j_joint.parents) { - ICHECK(nodes.count(p_name)) << "Can not find parent " << p_name; + TVM_FFI_ICHECK(nodes.count(p_name)) << "Can not find parent " << p_name; parents.push_back(nodes[p_name]); } for (const auto& in_name : j_joint.inputs) { @@ -426,7 +426,7 @@ void MSCJointNode::FromJson(const JsonMSCJoint& j_joint, break; } } - ICHECK(p_idx >= 0) << "Can not find parent for " << in_name; + TVM_FFI_ICHECK(p_idx >= 0) << "Can not find parent for " << in_name; ffi::Array input{Integer(p_idx), Integer(std::stol(index_str))}; inputs.push_back(input); } @@ -476,7 +476,7 @@ const ffi::Array MSCJointNode::GetOutputs() const { } const MSCTensor MSCJointNode::WeightAt(const ffi::String& wtype) const { - ICHECK(weights.count(wtype)) << "Can not find " << wtype << " from weights"; + TVM_FFI_ICHECK(weights.count(wtype)) << "Can not find " << wtype << " from weights"; return weights[wtype]; } @@ -516,7 +516,7 @@ const std::pair MSCJointNode::ProducerAndIdxOf(const ffi::Stri return ProducerAndIdxOf(i); } } - LOG(FATAL) << "Can not find producer of " << name; + TVM_FFI_THROW(InternalError) << "Can not find producer of " << name; } const std::pair MSCJointNode::ProducerAndIdxOf(const MSCTensor& input) const { @@ -572,7 +572,7 @@ void MSCPrimNode::FromJson(const JsonMSCPrim& j_prim, attrs.Set(pair.first, pair.second); } for (const auto& p_name : j_prim.parents) { - ICHECK(prims.count(p_name)) << "Can not find parent " << p_name; + TVM_FFI_ICHECK(prims.count(p_name)) << "Can not find parent " << p_name; parents.push_back(prims[p_name]); } } @@ -660,7 +660,7 @@ void WeightJointNode::FromJson(const JsonWeightJoint& j_joint, attrs.Set(pair.first, pair.second); } for (const auto& p_name : j_joint.parents) { - ICHECK(nodes.count(p_name)) << "Can not find parent " << p_name; + TVM_FFI_ICHECK(nodes.count(p_name)) << "Can not find parent " << p_name; parents.push_back(nodes[p_name]); } } @@ -814,12 +814,12 @@ const ffi::String MSCGraphNode::ToPrototxt() const { } const MSCJoint MSCGraphNode::FindNode(const ffi::String& name) const { - ICHECK(nodes.count(name)) << "Can not find node " << name; + TVM_FFI_ICHECK(nodes.count(name)) << "Can not find node " << name; return Downcast(nodes[name]); } const MSCPrim MSCGraphNode::FindPrim(const ffi::String& name) const { - ICHECK(prims.count(name)) << "Can not find prim " << name; + TVM_FFI_ICHECK(prims.count(name)) << "Can not find prim " << name; return prims[name]; } @@ -890,7 +890,7 @@ const MSCTensor MSCGraphNode::FindTensor(const ffi::String& name) const { return pair.second; } } - LOG(FATAL) << "Can not find weight " << name << " from " << node; + TVM_FFI_THROW(InternalError) << "Can not find weight " << name << " from " << node; } const auto& pair = FindProducerAndIdx(name); return pair.first->OutputAt(pair.second); @@ -911,12 +911,14 @@ const MSCJoint MSCGraphNode::FindProducer(const MSCTensor& tensor) const { const std::pair MSCGraphNode::FindProducerAndIdx(const ffi::String& name) const { const ffi::String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; - ICHECK(!weight_holders.count(tensor_name)) << "Weight " << name << " has no producer with index"; + TVM_FFI_ICHECK(!weight_holders.count(tensor_name)) + << "Weight " << name << " has no producer with index"; ffi::String host, index; std::tie(host, index) = StringUtils::SplitOnce(tensor_name, ":"); if (index.size() == 0) { const auto& node = FindNode(host); - ICHECK(node->optype == "constant") << "Tensor without index should be constant, get " << node; + TVM_FFI_ICHECK(node->optype == "constant") + << "Tensor without index should be constant, get " << node; return std::make_pair(node, 0); } return std::make_pair(FindNode(host), std::stoi(index)); @@ -949,7 +951,7 @@ const ffi::Array MSCGraphNode::FindConsumers(const MSCTensor& tensor) const std::vector> MSCGraphNode::FindConsumersAndIndices( const ffi::String& name) const { const ffi::String& tensor_name = tensor_alias.count(name) ? tensor_alias[name] : name; - ICHECK(!weight_holders.count(tensor_name)) << "Weight has no index"; + TVM_FFI_ICHECK(!weight_holders.count(tensor_name)) << "Weight has no index"; std::vector> consumers; for (const auto& c : FindConsumers(name)) { bool find_tensor = false; @@ -960,7 +962,7 @@ const std::vector> MSCGraphNode::FindConsumersAndInd break; } } - ICHECK(find_tensor) << "Can not find tensor " << name << " from " << c; + TVM_FFI_ICHECK(find_tensor) << "Can not find tensor " << name << " from " << c; } return consumers; } @@ -1166,7 +1168,7 @@ void WeightGraphNode::Build(const MSCGraph& graph, } const WeightJoint WeightGraphNode::FindNode(const ffi::String& name) const { - ICHECK(nodes.count(name)) << "Can not find node " << name; + TVM_FFI_ICHECK(nodes.count(name)) << "Can not find node " << name; return Downcast(nodes[name]); } @@ -1196,7 +1198,7 @@ void WeightGraphNode::FromJson(const JsonWeightGraph& j_graph) { for (const auto& j_joint : j_graph.nodes) { const auto& node = Downcast(nodes[j_joint.name]); for (const auto& f_name : j_joint.friends) { - ICHECK(nodes.count(f_name)) << "Can not find friend " << f_name; + TVM_FFI_ICHECK(nodes.count(f_name)) << "Can not find friend " << f_name; node->friends.push_back(nodes[f_name]); } } @@ -1249,7 +1251,7 @@ MSCGraph PruneWeights(const MSCGraph& graph, // define inputs std::vector> inputs; for (const auto& input : node->GetInputs()) { - ICHECK(inputs_map.count(input->name)) << "Can not find input " << input; + TVM_FFI_ICHECK(inputs_map.count(input->name)) << "Can not find input " << input; inputs.push_back(inputs_map[input->name]); } // define outputs diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h index d54a5f75f959..a649cadda0af 100644 --- a/src/contrib/msc/core/ir/graph.h +++ b/src/contrib/msc/core/ir/graph.h @@ -107,7 +107,7 @@ struct JsonMSCTensor { prims.push_back(std::string(elem.cast())); } } - ICHECK_EQ(bitmask, 1 | 2 | 4) << "name, dtype and shape should be given"; + TVM_FFI_ICHECK_EQ(bitmask, 1 | 2 | 4) << "name, dtype and shape should be given"; } }; @@ -233,7 +233,7 @@ struct JsonMSCJoint { weights[std::string(kv.first.cast())] = std::move(item); } } - ICHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "index, name, optype and outputs should be given"; + TVM_FFI_ICHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "index, name, optype and outputs should be given"; } }; @@ -300,7 +300,7 @@ struct JsonMSCPrim { std::string(kv.second.cast()); } } - ICHECK_EQ(bitmask, 1 | 2 | 4) << "index, name and optype should be given"; + TVM_FFI_ICHECK_EQ(bitmask, 1 | 2 | 4) << "index, name and optype should be given"; } }; @@ -395,7 +395,8 @@ struct JsonWeightJoint { std::string(kv.second.cast()); } } - ICHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "index, name, weight_type and weight should be given"; + TVM_FFI_ICHECK_EQ(bitmask, 1 | 2 | 4 | 8) + << "index, name, weight_type and weight should be given"; } }; @@ -490,7 +491,7 @@ struct JsonMSCGraph { prims.push_back(std::move(item)); } } - ICHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "name, inputs, outputs and nodes should be given"; + TVM_FFI_ICHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "name, inputs, outputs and nodes should be given"; } }; @@ -533,7 +534,7 @@ struct JsonWeightGraph { } bitmask |= 2; } - ICHECK_EQ(bitmask, 1 | 2) << "name and nodes should be given"; + TVM_FFI_ICHECK_EQ(bitmask, 1 | 2) << "name and nodes should be given"; } }; @@ -667,13 +668,13 @@ class BaseJointNode : public Object { template const T GetTypeAttr(const ffi::String& key) const { T val; - ICHECK(GetAttr(key, &val)) << "Can not get attr " << key; + TVM_FFI_ICHECK(GetAttr(key, &val)) << "Can not get attr " << key; return val; } template const std::vector GetTypeArrayAttr(const ffi::String& key) const { std::vector val; - ICHECK(GetAttr(key, &val)) << "Can not get attr " << key; + TVM_FFI_ICHECK(GetAttr(key, &val)) << "Can not get attr " << key; return val; } diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index df7a1520ebfa..cc713192d68b 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -96,7 +96,8 @@ void FuncParamsFinder::VisitExpr_(const CallNode* call_node) { if (const auto* v_node = call_node->op.as()) { func = Downcast(ref_module_->Lookup(v_node->name_hint)); } else if (call_node->op->IsInstance()) { - ICHECK(local_funcs_.count(call_node->op)) << "Can not find local func " << call_node->op; + TVM_FFI_ICHECK(local_funcs_.count(call_node->op)) + << "Can not find local func " << call_node->op; func = local_funcs_[call_node->op]; } if (func.defined()) { @@ -122,7 +123,8 @@ void LayoutsFinder::VisitExpr_(const CallNode* call_node) { func = Downcast(ref_module_->Lookup(v_node->name_hint)); VisitExpr(func); } else if (call_node->op->IsInstance()) { - ICHECK(local_funcs_.count(call_node->op)) << "Can not find local func " << call_node->op; + TVM_FFI_ICHECK(local_funcs_.count(call_node->op)) + << "Can not find local func " << call_node->op; func = local_funcs_[call_node->op]; } if (func.defined()) { @@ -179,7 +181,7 @@ const MSCGraph GraphBuilder::Build(const Function& func) { } else { LOG_FATAL << "Unexpected tuple input " << f << "(" << f->GetTypeKey() << ")"; } - ICHECK(expr_tensor_map_.count(f)) << "Can not find func param from tuple " << f; + TVM_FFI_ICHECK(expr_tensor_map_.count(f)) << "Can not find func param from tuple " << f; for (const auto& name : expr_tensor_map_[f]) { tuple_names.push_back(name); } @@ -188,7 +190,7 @@ const MSCGraph GraphBuilder::Build(const Function& func) { } else { AddNode(p, std::nullopt, p->name_hint()); } - ICHECK(expr_tensor_map_.count(p)) << "Can not find func param " << p; + TVM_FFI_ICHECK(expr_tensor_map_.count(p)) << "Can not find func param " << p; for (const auto& name : expr_tensor_map_[p]) { if (!added_inputs.count(name)) { input_names.push_back(name); @@ -197,7 +199,7 @@ const MSCGraph GraphBuilder::Build(const Function& func) { } } VisitExpr(func); - ICHECK(expr_tensor_map_.count(func->body->body)) + TVM_FFI_ICHECK(expr_tensor_map_.count(func->body->body)) << "Can not find seqexpr body " << func->body->body; output_names = expr_tensor_map_[func->body->body]; // remove const nodes as weights @@ -286,7 +288,8 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const ffi::Optional const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); std::tie(node_name, optype, layout) = ParseFunc(func); } else if (call_node->op->IsInstance()) { - ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op; + TVM_FFI_ICHECK(target_funcs_.count(call_node->op)) + << "Can not find target func: " << call_node->op; std::tie(node_name, optype, layout) = ParseFunc(target_funcs_[call_node->op]); } else if (call_node->op->IsInstance()) { std::tie(node_name, optype, layout) = ParseFunc(Downcast(call_node->op)); @@ -300,18 +303,19 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const ffi::Optional if (optype == "tuple" && expr->IsInstance() && Downcast(expr)->op->IsInstance()) { const auto& call_node = Downcast(expr); - ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op; + TVM_FFI_ICHECK(target_funcs_.count(call_node->op)) + << "Can not find target func: " << call_node->op; const auto& tuple_func = target_funcs_[call_node->op]; for (size_t i = 0; i < call_node->args.size(); i++) { expr_tensor_map_.Set(tuple_func->params[i], expr_tensor_map_[call_node->args[i]]); } VisitExpr(tuple_func); - ICHECK(expr_tensor_map_.count(tuple_func->body->body)) + TVM_FFI_ICHECK(expr_tensor_map_.count(tuple_func->body->body)) << "Can not find seqexpr body " << tuple_func->body->body; const auto& outputs = expr_tensor_map_[tuple_func->body->body]; const auto& ref_expr = binding_var.defined() ? binding_var.value() : expr; expr_tensor_map_.Set(ref_expr, outputs); - ICHECK(tensor_input_map_.count(outputs[0])) << "Can not find tensor " << outputs[0]; + TVM_FFI_ICHECK(tensor_input_map_.count(outputs[0])) << "Can not find tensor " << outputs[0]; return Downcast(tensor_input_map_[outputs[0]].first); } @@ -327,7 +331,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const ffi::Optional target_funcs_[op]->GetAttr>(msc_attr::kOpattrs); if (opattrs_opt.defined()) { const auto& opattrs = opattrs_opt.value(); - ICHECK_EQ(opattrs.size(), plugin->attrs.size()) + TVM_FFI_ICHECK_EQ(opattrs.size(), plugin->attrs.size()) << "opattrs " << opattrs << " size mismatch with " << plugin->attrs.size(); for (size_t i = 0; i < opattrs.size(); i++) { attrs.Set(plugin->attrs[i]->name, opattrs[i]); @@ -348,7 +352,8 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const ffi::Optional attrs = FuncAttrGetter().GetAttrs(func); } } else if (call_node->op->IsInstance()) { - ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op; + TVM_FFI_ICHECK(target_funcs_.count(call_node->op)) + << "Can not find target func: " << call_node->op; attrs = FuncAttrGetter().GetAttrs(target_funcs_[call_node->op]); } else if (call_node->op->IsInstance()) { attrs = FuncAttrGetter().GetAttrs(call_node->op); @@ -372,7 +377,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const ffi::Optional const auto& call = Downcast(expr); ffi::Array values; if (call->op->IsInstance()) { - ICHECK(target_funcs_.count(call->op)) << "Can not find target func: " << call->op; + TVM_FFI_ICHECK(target_funcs_.count(call->op)) << "Can not find target func: " << call->op; values = FuncValueGetter().GetValues(target_funcs_[call->op]); } input_types = ExprUtils::GetInputTypes(optype, call->args.size() + values.size(), true); @@ -385,8 +390,9 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const ffi::Optional attrs.Set(input_types[i], StringUtils::ToString(s_node->values)); ignore_nodes_.insert(Downcast(arg)->name_hint()); } else if (const auto* s_node = arg.as()) { - ICHECK(input_types[i] != "input") << i << " th PrimValue of " << optype - << " should has special type, get " << input_types; + TVM_FFI_ICHECK(input_types[i] != "input") + << i << " th PrimValue of " << optype << " should has special type, get " + << input_types; attrs.Set(input_types[i], StringUtils::ToString(s_node->value)); } else if (input_types[i] != "input" && arg->IsInstance()) { attrs.Set(input_types[i], StringUtils::ToString(arg)); @@ -403,13 +409,13 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const ffi::Optional if (plugin.defined()) { const auto& call = Downcast(expr); if (call->args.size() == 1) { - ICHECK(expr_tensor_map_.count(call->args[0])) + TVM_FFI_ICHECK(expr_tensor_map_.count(call->args[0])) << "Can not find tuple plugin input " << call->args[0]; input_names = expr_tensor_map_[call->args[0]]; } else { const auto& args = GetPluginInputs(expr); for (size_t i = 0; i < plugin->inputs.size(); i++) { - ICHECK(expr_tensor_map_.count(args[i])) << "Can not find plugin input " << args[i]; + TVM_FFI_ICHECK(expr_tensor_map_.count(args[i])) << "Can not find plugin input " << args[i]; for (const auto& in_name : expr_tensor_map_[args[i]]) { input_names.push_back(in_name); } @@ -427,7 +433,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const ffi::Optional } else if (input_types[i] == "input" && arg->IsInstance()) { const auto* tuple_node = arg.as(); for (const auto& f : tuple_node->fields) { - ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f; + TVM_FFI_ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f; for (const auto& in_name : expr_tensor_map_[f]) { arg_names.push_back(in_name); } @@ -477,19 +483,19 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const ffi::Optional } } else if (const auto* tuple_node = expr.as()) { for (const auto& f : tuple_node->fields) { - ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f; + TVM_FFI_ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f; for (const auto& in_name : expr_tensor_map_[f]) { input_names.push_back(in_name); } } } else if (const auto* getitem_node = expr.as()) { - ICHECK(expr_tensor_map_.count(getitem_node->tuple)) + TVM_FFI_ICHECK(expr_tensor_map_.count(getitem_node->tuple)) << "Can not find tuple " << getitem_node->tuple; input_names = expr_tensor_map_[getitem_node->tuple]; } else if (optype == "constant") { const auto& t_info = Downcast(GetStructInfo(expr)); const auto& shape_opt = t_info->GetShape(); - ICHECK(shape_opt.defined()) << "Constant shape is not defined"; + TVM_FFI_ICHECK(shape_opt.defined()) << "Constant shape is not defined"; const auto& weight = MSCTensor(node_name, t_info->dtype, layout, ArrayUtils::Cast(shape_opt.value())); node_weights.Set("const", weight); @@ -516,7 +522,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const ffi::Optional // Build output tensor auto build_output = [this](const StructInfo& sinfo, const ffi::String& node_name, const ffi::String& layout) { - ICHECK(sinfo->IsInstance()) + TVM_FFI_ICHECK(sinfo->IsInstance()) << "sinfo should be TensorStructInfo, get " << sinfo->GetTypeKey(); const auto& t_info = Downcast(sinfo); const auto& shape = ArrayUtils::Cast(ExprUtils::GetShape(t_info)); @@ -549,7 +555,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const ffi::Optional if (layouts.size() == 0) { layouts = ffi::Array(num_output, ""); } - ICHECK_EQ(layouts.size(), num_output) + TVM_FFI_ICHECK_EQ(layouts.size(), num_output) << "Layouts " << layouts << " msimatch with output size " << num_output; if (sinfo->IsInstance()) { const auto& t_name = node_name + ":" + std::to_string(0); @@ -566,7 +572,8 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const ffi::Optional outputs.push_back(build_output(tuple_sinfo->fields[i], t_name, layouts[i])); } } else { - LOG(FATAL) << "Unexpected struct info (" << sinfo->GetTypeKey() << ")" << sinfo; + TVM_FFI_THROW(InternalError) << "Unexpected struct info (" << sinfo->GetTypeKey() << ")" + << sinfo; } // Build node @@ -749,21 +756,22 @@ void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const TupleGetIt void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const VarNode* val) { ExprVisitor::VisitBinding_(binding, val); const auto& output = ffi::GetRef(val); - ICHECK(expr_tensor_map_.count(output)) << "Can not find var " << output; + TVM_FFI_ICHECK(expr_tensor_map_.count(output)) << "Can not find var " << output; expr_tensor_map_.Set(binding->var, expr_tensor_map_[output]); } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val) { ExprVisitor::VisitBinding_(binding, val); const auto& output = ffi::GetRef(val); - ICHECK(expr_tensor_map_.count(output)) << "Can not find dataflow var " << output; + TVM_FFI_ICHECK(expr_tensor_map_.count(output)) << "Can not find dataflow var " << output; expr_tensor_map_.Set(binding->var, expr_tensor_map_[output]); } void GraphBuilder::VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) { const auto& name_opt = val->GetAttr(relax::attr::kComposite); - ICHECK(name_opt.has_value()) << "Unexpected target func without composite"; - ICHECK(config_.target.size() > 0 && StringUtils::StartsWith(name_opt.value(), config_.target)) + TVM_FFI_ICHECK(name_opt.has_value()) << "Unexpected target func without composite"; + TVM_FFI_ICHECK(config_.target.size() > 0 && + StringUtils::StartsWith(name_opt.value(), config_.target)) << "Target should be given for target function"; target_funcs_.Set(binding->var, ffi::GetRef(val)); } @@ -806,9 +814,9 @@ void GraphBuilder::VisitPrimExpr(const PrimExpr& prim) { } ffi::Array GraphBuilder::GetPluginInputs(const Expr& expr) { - ICHECK(expr->IsInstance()) << "plugin expr should be call"; + TVM_FFI_ICHECK(expr->IsInstance()) << "plugin expr should be call"; const auto& call = Downcast(expr); - ICHECK(call->args[1]->IsInstance()) << "plugin argument 1 should be call"; + TVM_FFI_ICHECK(call->args[1]->IsInstance()) << "plugin argument 1 should be call"; return Downcast(call->args[1])->fields; } @@ -821,7 +829,7 @@ void WeightsExtractor::VisitExpr_(const ConstantNode* op) { const auto& name = SpanUtils::GetAttr(op->span, msc_attr::kName); const auto& layout = SpanUtils::GetAttr(op->span, msc_attr::kLayout); const auto& sinfo = GetStructInfo(ffi::GetRef(op)); - ICHECK(sinfo->IsInstance()) + TVM_FFI_ICHECK(sinfo->IsInstance()) << "Constant StrcutInfo should be TensorStructInfo"; const auto& t_info = Downcast(sinfo); const auto& opt_shape = t_info->GetShape(); diff --git a/src/contrib/msc/core/ir/graph_builder.h b/src/contrib/msc/core/ir/graph_builder.h index 88e90c6df24f..536623488a21 100644 --- a/src/contrib/msc/core/ir/graph_builder.h +++ b/src/contrib/msc/core/ir/graph_builder.h @@ -182,7 +182,7 @@ class AttrGetter { if (value.type_index() >= kTVMFFIStaticObjectBegin) { attrs_->Set(key, StringUtils::ToString(value.cast())); } else { - LOG(FATAL) << "Unsupported type: " << value.type_index(); + TVM_FFI_THROW(InternalError) << "Unsupported type: " << value.type_index(); } break; } diff --git a/src/contrib/msc/core/ir/plugin.h b/src/contrib/msc/core/ir/plugin.h index 74ce371125b8..2d2441353f18 100644 --- a/src/contrib/msc/core/ir/plugin.h +++ b/src/contrib/msc/core/ir/plugin.h @@ -73,7 +73,7 @@ struct JsonPluginAttr { if (auto it = obj.find(ffi::String("describe")); it != obj.end()) { describe = std::string((*it).second.cast()); } - ICHECK_EQ(bitmask, 1 | 2) << "name and type should be given for plugin attr"; + TVM_FFI_ICHECK_EQ(bitmask, 1 | 2) << "name and type should be given for plugin attr"; if (describe.size() == 0) { describe = "Plugin attribute " + name + "(" + type + ")"; } @@ -118,7 +118,7 @@ struct JsonPluginTensor { if (auto it = obj.find(ffi::String("describe")); it != obj.end()) { describe = std::string((*it).second.cast()); } - ICHECK_EQ(bitmask, 1) << "name should be given for plugin tensor"; + TVM_FFI_ICHECK_EQ(bitmask, 1) << "name should be given for plugin tensor"; if (describe.size() == 0) { describe = "Plugin tensor " + name + "(" + dtype + " on " + device + ")"; } @@ -163,7 +163,7 @@ struct JsonPluginExtern { if (auto it = obj.find(ffi::String("describe")); it != obj.end()) { describe = std::string((*it).second.cast()); } - ICHECK_EQ(bitmask, 1) << "name should be given for plugin extern"; + TVM_FFI_ICHECK_EQ(bitmask, 1) << "name should be given for plugin extern"; if (describe.size() == 0) { describe = "Plugin function " + name + "(from " + header + ")"; } @@ -335,16 +335,16 @@ struct JsonPlugin { std::string(kv.second.cast()); } } - ICHECK_EQ(bitmask, 1 | 2 | 4) << "name, inputs and outputs should be given for plugin"; + TVM_FFI_ICHECK_EQ(bitmask, 1 | 2 | 4) << "name, inputs and outputs should be given for plugin"; if (externs.size() > 0) { - ICHECK(externs.count("infer_output")) << "infer_output should be given as extern"; + TVM_FFI_ICHECK(externs.count("infer_output")) << "infer_output should be given as extern"; bool has_compute = false; for (const auto& pair : externs) { if (StringUtils::EndsWith(pair.first, "_compute")) { has_compute = true; } } - ICHECK(has_compute) << "No compute function found, please check"; + TVM_FFI_ICHECK(has_compute) << "No compute function found, please check"; } if (describe.size() == 0) { describe = "Plugin " + name + "(" + version + ")"; @@ -684,7 +684,7 @@ class PluginRegistry { */ const Plugin Get(const ffi::String& name) const { auto it = plugin_map_.find(name); - ICHECK(it != plugin_map_.end()) << "Can not find plugin " << name; + TVM_FFI_ICHECK(it != plugin_map_.end()) << "Can not find plugin " << name; return it->second; } diff --git a/src/contrib/msc/core/printer/cpp_printer.cc b/src/contrib/msc/core/printer/cpp_printer.cc index 8c2a512a6d86..54de66638c06 100644 --- a/src/contrib/msc/core/printer/cpp_printer.cc +++ b/src/contrib/msc/core/printer/cpp_printer.cc @@ -45,7 +45,7 @@ void CppPrinter::PrintTypedDoc(const LiteralDoc& doc) { } void CppPrinter::PrintTypedDoc(const IndexDoc& doc) { - ICHECK(doc->indices.size() == 1) << "CppPrinter only support 1 size indices"; + TVM_FFI_ICHECK(doc->indices.size() == 1) << "CppPrinter only support 1 size indices"; PrintDoc(doc->value, false); output_ << "["; PrintDoc(doc->indices[0], false); @@ -75,7 +75,7 @@ void CppPrinter::PrintTypedDoc(const CallDoc& doc) { PrintDoc(doc->callee, false); output_ << "("; PrintJoinedDocs(doc->args); - ICHECK_EQ(doc->kwargs_keys.size(), doc->kwargs_values.size()) + TVM_FFI_ICHECK_EQ(doc->kwargs_keys.size(), doc->kwargs_values.size()) << "CallDoc should have equal number of elements in kwargs_keys and kwargs_values."; if (doc->args.size() > 0 && doc->kwargs_keys.size() > 0) { output_ << ", "; @@ -87,7 +87,7 @@ void CppPrinter::PrintTypedDoc(const CallDoc& doc) { } void CppPrinter::PrintTypedDoc(const AssignDoc& doc) { - ICHECK(doc->lhs.defined()) << "lhs should be given for assign"; + TVM_FFI_ICHECK(doc->lhs.defined()) << "lhs should be given for assign"; if (doc->annotation.defined()) { if (!IsEmptyDoc(doc->annotation.value())) { PrintDoc(doc->annotation.value(), false); @@ -133,7 +133,7 @@ void CppPrinter::PrintTypedDoc(const ForDoc& doc) { MaybePrintComment(doc, true); if (doc->rhs->IsInstance()) { const auto& tuple = Downcast(doc->rhs); - ICHECK_EQ(tuple->elements.size(), 2) << "For with tuple should has 2 elements"; + TVM_FFI_ICHECK_EQ(tuple->elements.size(), 2) << "For with tuple should has 2 elements"; output_ << "for (size_t "; PrintDoc(doc->lhs, false); output_ << " = "; @@ -159,7 +159,7 @@ void CppPrinter::PrintTypedDoc(const ForDoc& doc) { void CppPrinter::PrintTypedDoc(const ScopeDoc& doc) { MaybePrintComment(doc, true); - ICHECK(doc->rhs.defined()) << "rhs should be given for scope"; + TVM_FFI_ICHECK(doc->rhs.defined()) << "rhs should be given for scope"; PrintDoc(doc->rhs, false); PrintIndentedBlock(doc->body); } @@ -167,7 +167,8 @@ void CppPrinter::PrintTypedDoc(const ScopeDoc& doc) { void CppPrinter::PrintTypedDoc(const FunctionDoc& doc) { MaybePrintComment(doc, true); for (const AssignDoc& arg_doc : doc->args) { - ICHECK(!arg_doc->comment.has_value()) << "Function arg cannot have comment attached to them."; + TVM_FFI_ICHECK(!arg_doc->comment.has_value()) + << "Function arg cannot have comment attached to them."; } if (doc->return_type.defined()) { if (!IsEmptyDoc(doc->return_type.value())) { @@ -273,7 +274,7 @@ void CppPrinter::PrintTypedDoc(const StructDoc& doc) { void CppPrinter::PrintTypedDoc(const ConstructorDoc& doc) { MaybePrintComment(doc, true); for (const AssignDoc& arg_doc : doc->args) { - ICHECK(!arg_doc->comment.has_value()) + TVM_FFI_ICHECK(!arg_doc->comment.has_value()) << "Constructor arg cannot have comment attached to them."; } PrintDoc(doc->name, false); @@ -294,7 +295,8 @@ void CppPrinter::PrintTypedDoc(const ConstructorDoc& doc) { void CppPrinter::PrintTypedDoc(const LambdaDoc& doc) { MaybePrintComment(doc, true); for (const AssignDoc& arg_doc : doc->args) { - ICHECK(!arg_doc->comment.has_value()) << "Function arg cannot have comment attached to them."; + TVM_FFI_ICHECK(!arg_doc->comment.has_value()) + << "Function arg cannot have comment attached to them."; } output_ << "auto "; PrintDoc(doc->name, false); @@ -317,7 +319,7 @@ void CppPrinter::PrintTypedDoc(const LambdaDoc& doc) { void CppPrinter::PrintTypedDoc(const SwitchDoc& doc) { MaybePrintComment(doc, true); - ICHECK_EQ(doc->predicates.size(), doc->branchs.size()) + TVM_FFI_ICHECK_EQ(doc->predicates.size(), doc->branchs.size()) << "predicates " << doc->predicates.size() << " mismatch with branchs " << doc->branchs.size(); for (size_t i = 0; i < doc->predicates.size(); i++) { diff --git a/src/contrib/msc/core/printer/cpp_printer.h b/src/contrib/msc/core/printer/cpp_printer.h index 62e205a7c749..fa55b13ddcb9 100644 --- a/src/contrib/msc/core/printer/cpp_printer.h +++ b/src/contrib/msc/core/printer/cpp_printer.h @@ -119,25 +119,25 @@ class CppPrinter : public MSCBasePrinter { /*! \brief Exit a endline scope*/ void ExitEndlineScope() { - ICHECK(endlines_.size() > 1) << "No endline scope found"; + TVM_FFI_ICHECK(endlines_.size() > 1) << "No endline scope found"; endlines_.pop_back(); } /*! \brief enable enbline*/ void EnableEndline() { - ICHECK(endlines_.size() > 0) << "No endline scope found"; + TVM_FFI_ICHECK(endlines_.size() > 0) << "No endline scope found"; endlines_[endlines_.size() - 1] = true; } /*! \brief disable enbline*/ void DisableEndline() { - ICHECK(endlines_.size() > 0) << "No endline scope found"; + TVM_FFI_ICHECK(endlines_.size() > 0) << "No endline scope found"; endlines_[endlines_.size() - 1] = false; } /*! \brief Print endline*/ void Endline() { - ICHECK(endlines_.size() > 0) << "No endline scope found"; + TVM_FFI_ICHECK(endlines_.size() > 0) << "No endline scope found"; if (endlines_[endlines_.size() - 1]) { output_ << ";"; } diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc b/src/contrib/msc/core/printer/msc_base_printer.cc index 644692aa6b66..fa6fc378f24f 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.cc +++ b/src/contrib/msc/core/printer/msc_base_printer.cc @@ -93,7 +93,7 @@ void MSCBasePrinter::PrintDoc(const Doc& doc, bool new_line) { } else if (auto doc_node = doc.as()) { PrintTypedDoc(doc_node.value()); } else { - LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Do not know how to print " << doc->GetTypeKey(); throw; } } @@ -116,7 +116,7 @@ void MSCBasePrinter::PrintTypedDoc(const LiteralDoc& doc) { } else if (auto opt_str = value.as()) { output_ << "\"" << tvm::support::StrEscape((*opt_str).data(), (*opt_str).size()) << "\""; } else { - LOG(FATAL) << "TypeError: Unsupported literal value type: " << value.GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Unsupported literal value type: " << value.GetTypeKey(); } } diff --git a/src/contrib/msc/core/printer/msc_base_printer.h b/src/contrib/msc/core/printer/msc_base_printer.h index 96981f02b7be..048eb25f8c90 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.h +++ b/src/contrib/msc/core/printer/msc_base_printer.h @@ -121,77 +121,109 @@ class MSCBasePrinter { virtual void PrintTypedDoc(const ExprStmtDoc& doc); /*! \brief Virtual method to print an IndexDoc*/ - virtual void PrintTypedDoc(const IndexDoc& doc) { LOG(FATAL) << "Index is not implemented"; } + virtual void PrintTypedDoc(const IndexDoc& doc) { + TVM_FFI_THROW(InternalError) << "Index is not implemented"; + } /*! \brief Virtual method to print a CallDoc*/ - virtual void PrintTypedDoc(const CallDoc& doc) { LOG(FATAL) << "Call is not implemented"; } + virtual void PrintTypedDoc(const CallDoc& doc) { + TVM_FFI_THROW(InternalError) << "Call is not implemented"; + } /*! \brief Virtual method to print an AttrAccessDoc*/ virtual void PrintTypedDoc(const AttrAccessDoc& doc) { - LOG(FATAL) << "AttrAccess is not implemented"; + TVM_FFI_THROW(InternalError) << "AttrAccess is not implemented"; } /*! \brief Virtual method to print a DictDoc*/ - virtual void PrintTypedDoc(const DictDoc& doc) { LOG(FATAL) << "Dict is not implemented"; } + virtual void PrintTypedDoc(const DictDoc& doc) { + TVM_FFI_THROW(InternalError) << "Dict is not implemented"; + } /*! \brief Virtual method to print a SliceDoc*/ - virtual void PrintTypedDoc(const SliceDoc& doc) { LOG(FATAL) << "Slice is not implemented"; } + virtual void PrintTypedDoc(const SliceDoc& doc) { + TVM_FFI_THROW(InternalError) << "Slice is not implemented"; + } /*! \brief Virtual method to print an AssignDoc*/ - virtual void PrintTypedDoc(const AssignDoc& doc) { LOG(FATAL) << "Assign is not implemented"; } + virtual void PrintTypedDoc(const AssignDoc& doc) { + TVM_FFI_THROW(InternalError) << "Assign is not implemented"; + } /*! \brief Virtual method to print an IfDoc*/ - virtual void PrintTypedDoc(const IfDoc& doc) { LOG(FATAL) << "If is not implemented"; } + virtual void PrintTypedDoc(const IfDoc& doc) { + TVM_FFI_THROW(InternalError) << "If is not implemented"; + } /*! \brief Virtual method to print a WhileDoc*/ - virtual void PrintTypedDoc(const WhileDoc& doc) { LOG(FATAL) << "While is not implemented"; } + virtual void PrintTypedDoc(const WhileDoc& doc) { + TVM_FFI_THROW(InternalError) << "While is not implemented"; + } /*! \brief Virtual method to print a ForDoc*/ - virtual void PrintTypedDoc(const ForDoc& doc) { LOG(FATAL) << "For is not implemented"; } + virtual void PrintTypedDoc(const ForDoc& doc) { + TVM_FFI_THROW(InternalError) << "For is not implemented"; + } /*! \brief Virtual method to print a ScopeDoc*/ - virtual void PrintTypedDoc(const ScopeDoc& doc) { LOG(FATAL) << "Scope is not implemented"; } + virtual void PrintTypedDoc(const ScopeDoc& doc) { + TVM_FFI_THROW(InternalError) << "Scope is not implemented"; + } /*! \brief Virtual method to print an AssertDoc*/ - virtual void PrintTypedDoc(const AssertDoc& doc) { LOG(FATAL) << "Assert is not implemented"; } + virtual void PrintTypedDoc(const AssertDoc& doc) { + TVM_FFI_THROW(InternalError) << "Assert is not implemented"; + } /*! \brief Virtual method to print a FunctionDoc*/ virtual void PrintTypedDoc(const FunctionDoc& doc) { - LOG(FATAL) << "Function is not implemented"; + TVM_FFI_THROW(InternalError) << "Function is not implemented"; } /*! \brief Virtual method to print a ClassDoc*/ - virtual void PrintTypedDoc(const ClassDoc& doc) { LOG(FATAL) << "Class is not implemented"; } + virtual void PrintTypedDoc(const ClassDoc& doc) { + TVM_FFI_THROW(InternalError) << "Class is not implemented"; + } /*! \brief Virtual method to print a CommentDoc*/ - virtual void PrintTypedDoc(const CommentDoc& doc) { LOG(FATAL) << "Comment is not implemented"; } + virtual void PrintTypedDoc(const CommentDoc& doc) { + TVM_FFI_THROW(InternalError) << "Comment is not implemented"; + } /*! \brief Virtual method to print a DeclareDoc*/ - virtual void PrintTypedDoc(const DeclareDoc& doc) { LOG(FATAL) << "Declare is not implemented"; } + virtual void PrintTypedDoc(const DeclareDoc& doc) { + TVM_FFI_THROW(InternalError) << "Declare is not implemented"; + } /*! \brief Virtual method to print a StrictListDoc*/ virtual void PrintTypedDoc(const StrictListDoc& doc) { - LOG(FATAL) << "StrictList is not implemented"; + TVM_FFI_THROW(InternalError) << "StrictList is not implemented"; } /*! \brief Virtual method to print a PointerDoc*/ virtual void PrintTypedDoc(const PointerDoc& doc) { - LOG(FATAL) << "PointerDoc is not implemented"; + TVM_FFI_THROW(InternalError) << "PointerDoc is not implemented"; } /*! \brief Virtual method to print a StructDoc*/ - virtual void PrintTypedDoc(const StructDoc& doc) { LOG(FATAL) << "StructDoc is not implemented"; } + virtual void PrintTypedDoc(const StructDoc& doc) { + TVM_FFI_THROW(InternalError) << "StructDoc is not implemented"; + } /*! \brief Virtual method to print a ConstructorDoc*/ virtual void PrintTypedDoc(const ConstructorDoc& doc) { - LOG(FATAL) << "ConstructorDoc is not implemented"; + TVM_FFI_THROW(InternalError) << "ConstructorDoc is not implemented"; } /*! \brief Virtual method to print a SwitchDoc*/ - virtual void PrintTypedDoc(const SwitchDoc& doc) { LOG(FATAL) << "SwitchDoc is not implemented"; } + virtual void PrintTypedDoc(const SwitchDoc& doc) { + TVM_FFI_THROW(InternalError) << "SwitchDoc is not implemented"; + } /*! \brief Virtual method to print a LambdaDoc*/ - virtual void PrintTypedDoc(const LambdaDoc& doc) { LOG(FATAL) << "LambdaDoc is not implemented"; } + virtual void PrintTypedDoc(const LambdaDoc& doc) { + TVM_FFI_THROW(InternalError) << "LambdaDoc is not implemented"; + } /*! \brief Print docs to joined doc */ template diff --git a/src/contrib/msc/core/printer/print_utils.cc b/src/contrib/msc/core/printer/print_utils.cc index 50d36df10bdb..e6ab2b28c152 100644 --- a/src/contrib/msc/core/printer/print_utils.cc +++ b/src/contrib/msc/core/printer/print_utils.cc @@ -95,7 +95,7 @@ const ffi::Array DocUtils::ToStmts(const ffi::Array& docs) { } else if (d->IsInstance()) { stmts.push_back(ExprStmtDoc(Downcast(d))); } else { - LOG(FATAL) << "Unecpected doc type " << d->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unecpected doc type " << d->GetTypeKey(); } } return stmts; diff --git a/src/contrib/msc/core/printer/prototxt_printer.cc b/src/contrib/msc/core/printer/prototxt_printer.cc index ffaf035385f1..299712ce9adc 100644 --- a/src/contrib/msc/core/printer/prototxt_printer.cc +++ b/src/contrib/msc/core/printer/prototxt_printer.cc @@ -88,10 +88,10 @@ void PrototxtPrinter::AppendPair(const ffi::String& key, const ffi::Any& value) } void PrototxtPrinter::PrintTypedDoc(const DictDoc& doc) { - ICHECK_EQ(doc->keys.size(), doc->values.size()) + TVM_FFI_ICHECK_EQ(doc->keys.size(), doc->values.size()) << "DictDoc should have equal number of elements in keys and values."; for (size_t i = 0; i < doc->keys.size(); i++) { - ICHECK(doc->keys[i].as()) + TVM_FFI_ICHECK(doc->keys[i].as()) << "Prototxt key should be IdDoc, get " << doc->keys[i]->GetTypeKey(); PrintDoc(doc->keys[i]); if (doc->values[i].as()) { diff --git a/src/contrib/msc/core/printer/python_printer.cc b/src/contrib/msc/core/printer/python_printer.cc index eb087f7f40e6..3966d8b3e5fe 100644 --- a/src/contrib/msc/core/printer/python_printer.cc +++ b/src/contrib/msc/core/printer/python_printer.cc @@ -66,7 +66,7 @@ void PythonPrinter::PrintTypedDoc(const CallDoc& doc) { PrintDoc(doc->callee, false); output_ << "("; PrintJoinedDocs(doc->args); - ICHECK_EQ(doc->kwargs_keys.size(), doc->kwargs_values.size()) + TVM_FFI_ICHECK_EQ(doc->kwargs_keys.size(), doc->kwargs_values.size()) << "CallDoc should have equal number of elements in kwargs_keys and kwargs_values."; if (doc->args.size() > 0 && doc->kwargs_keys.size() > 0) { output_ << ", "; @@ -124,7 +124,7 @@ void PythonPrinter::PrintTypedDoc(const ForDoc& doc) { MaybePrintComment(doc, true); if (doc->rhs->IsInstance()) { const auto& tuple = Downcast(doc->rhs); - ICHECK_EQ(tuple->elements.size(), 2) << "For with tuple should has 2 elements"; + TVM_FFI_ICHECK_EQ(tuple->elements.size(), 2) << "For with tuple should has 2 elements"; output_ << "for "; PrintDoc(doc->lhs, false); output_ << " in range("; @@ -157,7 +157,8 @@ void PythonPrinter::PrintTypedDoc(const ScopeDoc& doc) { void PythonPrinter::PrintTypedDoc(const FunctionDoc& doc) { for (const AssignDoc& arg_doc : doc->args) { - ICHECK(!arg_doc->comment.has_value()) << "Function arg cannot have comment attached to them."; + TVM_FFI_ICHECK(!arg_doc->comment.has_value()) + << "Function arg cannot have comment attached to them."; } PrintDecorators(doc->decorators); @@ -212,7 +213,7 @@ void PythonPrinter::PrintTypedDoc(const StrictListDoc& doc) { void PythonPrinter::PrintTypedDoc(const SwitchDoc& doc) { MaybePrintComment(doc, true); - ICHECK_EQ(doc->predicates.size(), doc->branchs.size()) + TVM_FFI_ICHECK_EQ(doc->predicates.size(), doc->branchs.size()) << "predicates " << doc->predicates.size() << " mismatch with branchs " << doc->branchs.size(); for (size_t i = 0; i < doc->predicates.size(); i++) { diff --git a/src/contrib/msc/core/transform/bind_named_params.cc b/src/contrib/msc/core/transform/bind_named_params.cc index 992c514ad7ef..08cff58e68b6 100644 --- a/src/contrib/msc/core/transform/bind_named_params.cc +++ b/src/contrib/msc/core/transform/bind_named_params.cc @@ -36,8 +36,8 @@ using namespace tvm::contrib::msc; std::tuple, ffi::Map> NormalizeNamedBindings( const Function& func, const ffi::Map& untyped_params) { - ICHECK(func.defined()); - ICHECK(untyped_params.defined()); + TVM_FFI_ICHECK(func.defined()); + TVM_FFI_ICHECK(untyped_params.defined()); // Map from string to the variable(s) with that name. std::unordered_map> string_lookup; @@ -53,28 +53,28 @@ std::tuple, ffi::Map> NormalizeNamedBind if (auto opt_str = obj.as()) { std::string str = opt_str.value(); auto it = string_lookup.find(str); - CHECK(it != string_lookup.end()) + TVM_FFI_ICHECK(it != string_lookup.end()) << "Function does not have parameter with name \"" << str << "\". " << "Function parameters are named " << func->params.Map([](const auto& param) { return param->name_hint(); }); - CHECK_EQ(it->second.size(), 1) + TVM_FFI_ICHECK_EQ(it->second.size(), 1) << "Function contains multiple parameters with name \"" << str << "\". " << "The Relax variables " << it->second << " are all named \"" << str << "\""; auto var = it->second[0]; - CHECK(!relax_var_remap.count(var)) + TVM_FFI_ICHECK(!relax_var_remap.count(var)) << "Remap of variable " << var << " was defined multiple times"; return var; } else if (auto opt_var = obj.as()) { auto var = opt_var.value(); - CHECK(!relax_var_remap.count(var)) + TVM_FFI_ICHECK(!relax_var_remap.count(var)) << "Remap of variable " << var << " was defined multiple times"; - CHECK(var_set.count(var.get())) + TVM_FFI_ICHECK(var_set.count(var.get())) << "Function does not use Relax variable " << var << " as a parameter. " << "Function parameters are " << func->params; return var; } else { - LOG(FATAL) + TVM_FFI_THROW(InternalError) << "Expected bound parameter to be a relax::Var, " << " or a string that uniquely identifies a relax::Var param within the function. " << "However, received object " << obj << " of type " << obj.GetTypeKey(); @@ -87,7 +87,8 @@ std::tuple, ffi::Map> NormalizeNamedBind const auto& span = SpanUtils::CreateWithAttr(msc_attr::kName, key->name_hint()); return Constant(opt.value(), StructInfo(), span); } else { - LOG(FATAL) << "Cannot coerce object of type " << obj.GetTypeKey() << " into relax expression"; + TVM_FFI_THROW(InternalError) + << "Cannot coerce object of type " << obj.GetTypeKey() << " into relax expression"; } }; diff --git a/src/contrib/msc/core/transform/bind_shape.cc b/src/contrib/msc/core/transform/bind_shape.cc index c9963ba94e84..4a196e0501f3 100644 --- a/src/contrib/msc/core/transform/bind_shape.cc +++ b/src/contrib/msc/core/transform/bind_shape.cc @@ -70,7 +70,7 @@ class ShapeBinder : public ExprMutator { } } // update main - ICHECK(main_var.defined()) << "Can not find entry func " << entry_name_; + TVM_FFI_ICHECK(main_var.defined()) << "Can not find entry func " << entry_name_; const auto& new_func = Downcast(VisitExpr(mod_->Lookup(entry_name_))); builder_->UpdateFunction(main_var, new_func); return builder_->GetContextIRModule(); @@ -91,11 +91,11 @@ class ShapeBinder : public ExprMutator { if (new_args.size() == call_node->args.size()) { ExprMutator::VisitBinding_(binding, call_node); } else if (const auto* op_node = call_node->op.as()) { - ICHECK(op_node->name == "relax.reshape" || op_node->name == "relax.image.resize2d") + TVM_FFI_ICHECK(op_node->name == "relax.reshape" || op_node->name == "relax.image.resize2d") << "Expect ShapeExpr consumer as reshape or image.resize2d, get " << ffi::GetRef(call_node); const auto& opt_shape = Downcast(GetStructInfo(call_node->args[1]))->values; - ICHECK(opt_shape.defined()) << "Expected shape defined, get " << call_node->args[1]; + TVM_FFI_ICHECK(opt_shape.defined()) << "Expected shape defined, get " << call_node->args[1]; new_args.push_back(ShapeExpr(opt_shape.value())); const auto& new_call = Call(call_node->op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); @@ -104,7 +104,7 @@ class ShapeBinder : public ExprMutator { const auto& func_info = Downcast(gv_node->struct_info_); ffi::Array params_info; for (const auto& a : new_args) { - ICHECK(a->struct_info_.defined()) + TVM_FFI_ICHECK(a->struct_info_.defined()) << "Global func argument without defined struct info " << a; params_info.push_back(Downcast(a->struct_info_.value())); } diff --git a/src/contrib/msc/core/transform/fuse_tuple.cc b/src/contrib/msc/core/transform/fuse_tuple.cc index 6f2913ac9599..08d8a995f8cb 100644 --- a/src/contrib/msc/core/transform/fuse_tuple.cc +++ b/src/contrib/msc/core/transform/fuse_tuple.cc @@ -61,7 +61,7 @@ class TupleFuser : public ExprMutator { } } // update main - ICHECK(main_var.defined()) << "Can not find entry func " << entry_name_; + TVM_FFI_ICHECK(main_var.defined()) << "Can not find entry func " << entry_name_; const auto& new_func = Downcast(VisitExpr(mod_->Lookup(entry_name_))); builder_->UpdateFunction(main_var, new_func); return builder_->GetContextIRModule(); @@ -85,7 +85,7 @@ class TupleFuser : public ExprMutator { } const auto& func_call = AddFunc(arg, tuple_name); const auto& tuple_out = builder_->Emit(func_call); - ICHECK(target_funcs_.count(func_call->op)) + TVM_FFI_ICHECK(target_funcs_.count(func_call->op)) << "Can not find target func " << func_call->op; target_funcs_.Set(tuple_out, target_funcs_[func_call->op]); has_tuple_arg = true; @@ -162,7 +162,7 @@ class TupleFuser : public ExprMutator { ffi::String func_name; Span expr_span = expr->span; if (!expr_span.defined()) { - ICHECK(tuple_name.size() > 0) << "Missing tuple for " << expr; + TVM_FFI_ICHECK(tuple_name.size() > 0) << "Missing tuple for " << expr; expr_span = SpanUtils::CreateWithAttr(msc_attr::kName, tuple_name); } if (expr->IsInstance()) { @@ -209,7 +209,8 @@ class TupleFuser : public ExprMutator { void ReEmitFunc(const VarBindingNode* binding, const Expr& expr) { const auto& func_call = AddFunc(expr); ReEmitBinding(binding, builder_->Normalize(func_call)); - ICHECK(target_funcs_.count(func_call->op)) << "Can not find target func " << func_call->op; + TVM_FFI_ICHECK(target_funcs_.count(func_call->op)) + << "Can not find target func " << func_call->op; target_funcs_.Set(binding->var, target_funcs_[func_call->op]); } diff --git a/src/contrib/msc/core/transform/inline_params.cc b/src/contrib/msc/core/transform/inline_params.cc index 9c5eb7536564..14f8c7896649 100644 --- a/src/contrib/msc/core/transform/inline_params.cc +++ b/src/contrib/msc/core/transform/inline_params.cc @@ -64,7 +64,7 @@ class ParamsInliner : public ExprMutator { } if (struct_info->IsInstance()) { const auto& optype_opt = func->GetAttr(msc_attr::kOptype); - ICHECK(optype_opt.has_value()) + TVM_FFI_ICHECK(optype_opt.has_value()) << "Can not find attr " << msc_attr::kOptype << " form extern func"; extern_types_.Set(p, optype_opt.value()); continue; @@ -75,7 +75,8 @@ class ParamsInliner : public ExprMutator { if (i->IsInstance()) { new_fields.push_back(i); } else if (const auto& p_info = i.as()) { - ICHECK(p_info->value.defined()) << "PrimStructInfo with undefined prim value " << i; + TVM_FFI_ICHECK(p_info->value.defined()) + << "PrimStructInfo with undefined prim value " << i; attrs.push_back(StringUtils::ToString(p_info->value.value())); } } @@ -99,7 +100,7 @@ class ParamsInliner : public ExprMutator { } } // update main - ICHECK(main_var.defined()) << "Can not find entry func " << entry_name_; + TVM_FFI_ICHECK(main_var.defined()) << "Can not find entry func " << entry_name_; const auto& new_func = Downcast(VisitExpr(mod_->Lookup(entry_name_))); builder_->UpdateFunction(main_var, new_func); return builder_->GetContextIRModule(); @@ -111,14 +112,14 @@ class ParamsInliner : public ExprMutator { for (const auto& a : call_node->args) { auto struct_info = GetStructInfo(a); if (a->IsInstance() && struct_info->IsInstance()) { - ICHECK(extern_types_.count(a)) << "Can not find extern type of " << a; + TVM_FFI_ICHECK(extern_types_.count(a)) << "Can not find extern type of " << a; new_args.push_back(ExternFunc(extern_types_[a])); has_inline = true; } else if (call_node->op->IsInstance() && a->IsInstance()) { has_inline = true; } else if (a->IsInstance() && struct_info->IsInstance()) { const auto& shape_opt = Downcast(GetStructInfo(a))->values; - ICHECK(shape_opt.defined()) << "Expected shape defined, get " << a; + TVM_FFI_ICHECK(shape_opt.defined()) << "Expected shape defined, get " << a; new_args.push_back(ShapeExpr(shape_opt.value())); has_inline = true; } else if (call_node->op->IsInstance() && a->IsInstance()) { @@ -155,7 +156,7 @@ class ParamsInliner : public ExprMutator { const auto& func_info = Downcast(gv_node->struct_info_); ffi::Array params_info; for (const auto& a : new_args) { - ICHECK(a->struct_info_.defined()) + TVM_FFI_ICHECK(a->struct_info_.defined()) << "Global func argument without defined struct info " << a; params_info.push_back(Downcast(a->struct_info_.value())); } diff --git a/src/contrib/msc/core/transform/layout_utils.cc b/src/contrib/msc/core/transform/layout_utils.cc index a4f46dce7fe4..e5fdfabe4daa 100644 --- a/src/contrib/msc/core/transform/layout_utils.cc +++ b/src/contrib/msc/core/transform/layout_utils.cc @@ -40,7 +40,7 @@ NLayout LayoutUtils::InferNLayout(const Expr& expr, const VarLayoutMap& var_layo LayoutDecision LayoutUtils::InferLayoutDecision(const Expr& expr, const VarLayoutMap& var_layout_map) { const auto& nlayout = InferNLayout(expr, var_layout_map); - ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << expr; + TVM_FFI_ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << expr; return nlayout.LeafValue(); } @@ -52,7 +52,7 @@ LayoutDecision LayoutUtils::InferLayoutDecisionAt(const Expr& expr, return index == 0 ? nlayouts.LeafValue() : LayoutDecision(""); } const auto& nlayout = nlayouts.NestedArray()[0]; - ICHECK(nlayout.IsLeaf()) << "Cannot get output layout for " << expr; + TVM_FFI_ICHECK(nlayout.IsLeaf()) << "Cannot get output layout for " << expr; return nlayout.LeafValue(); } @@ -121,7 +121,7 @@ const NLayout LayoutUtils::GetNLayout(const Expr& expr) { const LayoutDecision LayoutUtils::GetLayoutDecision(const Expr& expr) { NLayout nlayout = GetNLayout(expr); - ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << expr; + TVM_FFI_ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << expr; return nlayout.LeafValue(); } @@ -154,7 +154,7 @@ const LayoutDecision LayoutUtils::ExpandLayout(const LayoutDecision& src_layout, std::vector axes = expand_axes; std::sort(std::begin(axes), std::end(axes)); std::string new_layout = src_layout.name(); - ICHECK_EQ(new_layout.size(), src_layout->layout.ndim()) + TVM_FFI_ICHECK_EQ(new_layout.size(), src_layout->layout.ndim()) << "Only support normal layout, get " << src_layout->layout; std::set used_axes; for (size_t i = 0; i < src_layout->layout.ndim(); i++) { diff --git a/src/contrib/msc/core/transform/set_byoc_attrs.cc b/src/contrib/msc/core/transform/set_byoc_attrs.cc index 16ce44cede16..c459483481c7 100644 --- a/src/contrib/msc/core/transform/set_byoc_attrs.cc +++ b/src/contrib/msc/core/transform/set_byoc_attrs.cc @@ -74,7 +74,7 @@ class ByocNameSetter : public ExprMutator { void VisitBinding_(const VarBindingNode* binding, const CallNode* val) final { ExprMutator::VisitBinding_(binding, val); if (val->op->IsInstance()) { - ICHECK(local_funcs_.count(val->op)) << "Can not find local func " << val->op; + TVM_FFI_ICHECK(local_funcs_.count(val->op)) << "Can not find local func " << val->op; const auto& name_opt = local_funcs_[val->op]->GetAttr(msc_attr::kUnique); if (name_opt.has_value()) { val->span = SpanUtils::SetAttr(val->span, "name", name_opt.value()); diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc index 90dd47cb2d36..75273350afb4 100644 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -251,7 +251,7 @@ InferLayoutOutput ForwardInferLayoutBinary( input_layouts.push_back(output->input_layouts[i]); } } else { - LOG(FATAL) << "Binary input should be tensor, get " << sinfo->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Binary input should be tensor, get " << sinfo->GetTypeKey(); } } return InferLayoutOutput(input_layouts, output->output_layouts, Attrs()); @@ -676,7 +676,7 @@ InferLayoutOutput BackwardInferLayoutBinary( input_layouts.push_back(output->input_layouts[i]); } } else { - LOG(FATAL) << "Binary input should be tensor, get " << sinfo->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Binary input should be tensor, get " << sinfo->GetTypeKey(); } } return InferLayoutOutput(input_layouts, output->output_layouts, Attrs()); @@ -1283,7 +1283,7 @@ class LayoutInfer : public ExprVisitor { SetExprLayout(ret, var_layout_map_[Downcast(b_node->body)]); } } else { - LOG(FATAL) << "Function body should be SeqExpr, get " << func->body; + TVM_FFI_THROW(InternalError) << "Function body should be SeqExpr, get " << func->body; } } @@ -1303,7 +1303,7 @@ class LayoutInfer : public ExprVisitor { SetExprLayout(b_node->body, param_layout); } } else { - LOG(FATAL) << "Caller body should be SeqExpr, get " << caller->body; + TVM_FFI_THROW(InternalError) << "Caller body should be SeqExpr, get " << caller->body; } } } @@ -1325,7 +1325,7 @@ class LayoutChecker : public ExprVisitor { void Check(const Expr& expr) { ExprVisitor::VisitExpr(expr); - ICHECK_EQ(missing_num_, 0) << "Some layout is missing"; + TVM_FFI_ICHECK_EQ(missing_num_, 0) << "Some layout is missing"; } void VisitExpr_(const CallNode* call) final { diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index b917bc47a24e..cf25392c8973 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -40,7 +40,7 @@ size_t CommonUtils::GetIndex(int index, size_t max_size) { } else { v_index = index; } - ICHECK_LT(v_index, max_size) << "Index " << index << " out of range " << max_size; + TVM_FFI_ICHECK_LT(v_index, max_size) << "Index " << index << " out of range " << max_size; return v_index; } @@ -57,8 +57,9 @@ int CommonUtils::CompareVersion(const std::vector& given_version, if (given_version.size() == 0 || target_version.size() == 0) { return 0; } - ICHECK_EQ(given_version.size(), 3) << "Version should be in format major,minor,patch"; - ICHECK_EQ(target_version.size(), 3) << "Target version should be in format major,minor,patch"; + TVM_FFI_ICHECK_EQ(given_version.size(), 3) << "Version should be in format major,minor,patch"; + TVM_FFI_ICHECK_EQ(target_version.size(), 3) + << "Target version should be in format major,minor,patch"; for (size_t i = 0; i < 3; i++) { if (given_version[i] > target_version[i]) { return 1; @@ -310,7 +311,7 @@ bool ArrayUtils::CompareArrays(const ffi::Array& left, return false; } size = left.size(); - ICHECK_GT(size, 0) << "Positive size should be given, get " << size; + TVM_FFI_ICHECK_GT(size, 0) << "Positive size should be given, get " << size; if (size > static_cast(left.size()) || size > static_cast(right.size())) { return false; } @@ -492,7 +493,7 @@ const ffi::Array ExprUtils::GetInputTypes(const ffi::String& optype input_types.push_back("input"); } } - ICHECK_EQ(input_types.size(), inputs_num) + TVM_FFI_ICHECK_EQ(input_types.size(), inputs_num) << "Optype " << optype << " get input types " << input_types << " and inputs_num " << inputs_num << " mismatch"; return input_types; diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h index a0732d5848ac..de6294bb45be 100644 --- a/src/contrib/msc/core/utils.h +++ b/src/contrib/msc/core/utils.h @@ -361,7 +361,7 @@ class ExprUtils { return T(reinterpret_cast(array->data)[i]); } } - LOG(FATAL) << "Failed to get scalar from array " << array; + TVM_FFI_THROW(InternalError) << "Failed to get scalar from array " << array; } /*! diff --git a/src/contrib/msc/framework/tensorflow/codegen.cc b/src/contrib/msc/framework/tensorflow/codegen.cc index 30488fcc9af0..45b7bdbc341b 100644 --- a/src/contrib/msc/framework/tensorflow/codegen.cc +++ b/src/contrib/msc/framework/tensorflow/codegen.cc @@ -142,7 +142,8 @@ void TensorflowCodeGen::CodeGenInference() { const ffi::Array TensorflowCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetTFV1OpCodes(); auto it = ops_map->find(node->optype); - ICHECK(it != ops_map->end()) << "Unsupported tensorflow op(" << node->optype << "): " << node; + TVM_FFI_ICHECK(it != ops_map->end()) + << "Unsupported tensorflow op(" << node->optype << "): " << node; it->second->Config(node, config(), prims()); try { return it->second->GetDocs(); diff --git a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc index d47021d84da5..5a603454ae1a 100644 --- a/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc +++ b/src/contrib/msc/framework/tensorflow/tf_v1_opcode.cc @@ -45,7 +45,7 @@ const std::pair> TFV1OpCode::GetPadding( kernel_size.push_back(weight->DimAt("H")->value); kernel_size.push_back(weight->DimAt("W")->value); } else if (node()->optype == "nn.avg_pool2d" || node()->optype == "nn.max_pool2d") { - ICHECK(node()->GetAttr(kernel_key, &kernel_size)); + TVM_FFI_ICHECK(node()->GetAttr(kernel_key, &kernel_size)); } else { LOG_FATAL << "Unexpected padding node" << node(); } @@ -328,14 +328,15 @@ class TFV1PadCodeGen : public TFV1OpCode { } ffi::Array pad_width; const auto& attr_pad_width = node()->GetTypeArrayAttr("pad_width"); - ICHECK(attr_pad_width.size() % 2 == 0) << "pad_width should be multiple of 2, get " << node(); + TVM_FFI_ICHECK(attr_pad_width.size() % 2 == 0) + << "pad_width should be multiple of 2, get " << node(); for (size_t i = 0; i < attr_pad_width.size(); i += 2) { const ffi::String& cur_pad = "[" + std::to_string(attr_pad_width[i]) + ", " + std::to_string(attr_pad_width[i + 1]) + "]"; pad_width.push_back(cur_pad); } const auto& val_producer = node()->ProducerOf(1); - ICHECK(val_producer->optype == "constant" && val_producer->HasAttr("scalar")); + TVM_FFI_ICHECK(val_producer->optype == "constant" && val_producer->HasAttr("scalar")); stack_.op_call() .op_input_arg() .call_arg(DocUtils::ToList(pad_width), "paddings") diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc index 1be8cf0836c9..8cf746f825aa 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.cc +++ b/src/contrib/msc/framework/tensorrt/codegen.cc @@ -545,7 +545,8 @@ const ffi::String TensorRTCodeGen::ToDims(const ffi::Array& dims, bool const ffi::Array TensorRTCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetTensorRTOpCodes(); auto it = ops_map->find(GetOpType(node)); - ICHECK(it != ops_map->end()) << "Unsupported tensorrt op(" << node->optype << "): " << node; + TVM_FFI_ICHECK(it != ops_map->end()) + << "Unsupported tensorrt op(" << node->optype << "): " << node; it->second->Config(node, config(), prims()); try { return it->second->GetDocs(); @@ -607,10 +608,10 @@ ffi::Array MSCTensorRTCompiler(ffi::Array functions, for (const auto& func : functions) { VLOG(1) << "MSC.TensorRT partition:" << std::endl << func; const auto& name_opt = func->GetAttr(msc_attr::kUnique); - ICHECK(name_opt.has_value()) << "Can not find " << msc_attr::kUnique << " from attrs"; + TVM_FFI_ICHECK(name_opt.has_value()) << "Can not find " << msc_attr::kUnique << " from attrs"; const auto& name = name_opt.value(); std::string func_name = GetExtSymbol(func); - ICHECK(target_option.count(name)) << "Can not find target option for " << name; + TVM_FFI_ICHECK(target_option.count(name)) << "Can not find target option for " << name; const auto& options = Downcast(target_option[name]); MSCJSONSerializer serializer(constant_names, options); serializer.serialize(func); diff --git a/src/contrib/msc/framework/tensorrt/codegen_utils.h b/src/contrib/msc/framework/tensorrt/codegen_utils.h index b06d92dcd82b..df68ec66ab1d 100644 --- a/src/contrib/msc/framework/tensorrt/codegen_utils.h +++ b/src/contrib/msc/framework/tensorrt/codegen_utils.h @@ -56,14 +56,14 @@ class TensorRTCodeGenHelper : public BaseCodeGenHelper { const ffi::String IdxOutputBase(const MSCJoint& node, const ffi::String& prefix = "", int idx = 0, const ffi::String& suffix = "", bool mark_exit = false) final { if (node->optype == "argmax" || node->optype == "argmin") { - ICHECK_EQ(idx, 0) << "argmax and argmin only has 1 output, get " << idx; + TVM_FFI_ICHECK_EQ(idx, 0) << "argmax and argmin only has 1 output, get " << idx; return IdxNodeBase(node, prefix, suffix) + "->getOutput(1)"; } if (node->optype == "tuple") { return IdxNodeBase(node, prefix, suffix) + "[" + std::to_string(idx) + "]"; } if (node->optype == "get_item") { - ICHECK_EQ(idx, 0) << "get item only has 1 output, get " << idx; + TVM_FFI_ICHECK_EQ(idx, 0) << "get item only has 1 output, get " << idx; return IdxNodeBase(node, prefix, suffix); } return IdxNodeBase(node, prefix, suffix) + "->getOutput(" + std::to_string(idx) + ")"; diff --git a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc index 4fde2bf8bc2e..f2f5baaa8277 100644 --- a/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc +++ b/src/contrib/msc/framework/tensorrt/tensorrt_opcode.cc @@ -146,7 +146,7 @@ const size_t TensorRTOpCode::AttrToReduceAxis(const ffi::String& key, size_t ndi return ToReduceAxis(axes, ndim); } int axis; - ICHECK(node()->GetAttr(key, &axis)) << "Can not get axes from attribute key " << key; + TVM_FFI_ICHECK(node()->GetAttr(key, &axis)) << "Can not get axes from attribute key " << key; return ToReduceAxis(std::vector{axis}, ndim); } @@ -251,7 +251,7 @@ class TensorRTArgmaxminCodeGen : public TensorRTOpCode { protected: void CodeGenBuild() final { - ICHECK(node()->GetTypeAttr("keepdims")) << "Only support argsort with keepdims"; + TVM_FFI_ICHECK(node()->GetTypeAttr("keepdims")) << "Only support argsort with keepdims"; stack_.op_call() .op_input_arg() .call_arg("TopKOperation::k" + symbol_) @@ -300,7 +300,7 @@ class TensorRTConcatCodeGen : public TensorRTOpCode { protected: void CodeGenBuild() final { const auto& producer = node()->ProducerOf(0); - ICHECK(node()->parents.size() == 1 && producer->optype == "tuple") + TVM_FFI_ICHECK(node()->parents.size() == 1 && producer->optype == "tuple") << "Concat expect parent as tuple, get " << node(); stack_.op_call().call_arg(IdxNodeBase(producer)).call_arg(producer->inputs.size()); SetLayerByValue("Axis", AttrToAxis()); @@ -313,7 +313,7 @@ class TensorRTConstantCodeGen : public TensorRTOpCode { protected: void CodeGenBuild() final { - ICHECK(!node()->HasAttr("scalar")) << "Scalar constant is not supported"; + TVM_FFI_ICHECK(!node()->HasAttr("scalar")) << "Scalar constant is not supported"; stack_.op_call().call_arg(ToDims(node()->OutputAt(0)->shape)).op_weight_arg("const"); } }; @@ -438,7 +438,8 @@ class TensorRTPadCodeGen : public TensorRTOpCode { protected: void CodeGenBuild() final { const auto& pad_width = node()->GetTypeArrayAttr("pad_width"); - ICHECK(pad_width.size() % 2 == 0) << "pad_width should be multiple of 2, get " << node(); + TVM_FFI_ICHECK(pad_width.size() % 2 == 0) + << "pad_width should be multiple of 2, get " << node(); std::vector pre_padding{2, 0}, post_padding{2, 0}; const auto& input = node()->InputAt(0); for (size_t i = 0; i < input->Ndim(); i++) { @@ -715,7 +716,7 @@ class TensorRTPluginOpCodeGen : public TensorRTOpCode { protected: void CodeGenBuild() final { const auto& producer = node()->ParentAt(0); - ICHECK(producer->optype == "tuple") + TVM_FFI_ICHECK(producer->optype == "tuple") << "Only support tensorrt plugin with tuple, get " << producer; const auto& plugin = GetPlugin(node()->optype); diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index ef58e31972df..4b6beb1164ad 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -89,7 +89,7 @@ const ffi::Array BroadcastShape(const ffi::Array& src_shape, if (ArrayUtils::Broadcastable(tailing_shape, out_shape)) { return tailing_shape; } - ICHECK(ArrayUtils::Broadcastable(leading_shape, out_shape)) + TVM_FFI_ICHECK(ArrayUtils::Broadcastable(leading_shape, out_shape)) << "Only support elemwise ops with leading or tailing expand"; return leading_shape; } @@ -159,7 +159,7 @@ Expr RewriteArgmaxmin(BlockBuilder builder, const Var& var, const Call& src_call const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call; const auto& out_dtype = ExprUtils::GetDataType(var); const auto* src_attrs = src_call->attrs.as(); - ICHECK(out_dtype == DataType::Int(32) || out_dtype == DataType::Int(64)) + TVM_FFI_ICHECK(out_dtype == DataType::Int(32) || out_dtype == DataType::Int(64)) << "Unexpected out dtype " << out_dtype; static const Op& topk_op = Op::Get("relax.topk"); auto topk_attrs = ffi::make_object(); diff --git a/src/contrib/msc/framework/torch/codegen.cc b/src/contrib/msc/framework/torch/codegen.cc index c81646f8b267..4607b0d94bef 100644 --- a/src/contrib/msc/framework/torch/codegen.cc +++ b/src/contrib/msc/framework/torch/codegen.cc @@ -143,7 +143,7 @@ void TorchCodeGen::CodeGenInference() { const ffi::Array TorchCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetTorchOpCodes(); auto it = ops_map->find(GetOpType(node)); - ICHECK(it != ops_map->end()) << "Unsupported torch op(" << node->optype << "): " << node; + TVM_FFI_ICHECK(it != ops_map->end()) << "Unsupported torch op(" << node->optype << "): " << node; it->second->Config(node, config(), is_init_, prims()); try { return it->second->GetDocs(); diff --git a/src/contrib/msc/framework/torch/codegen_utils.h b/src/contrib/msc/framework/torch/codegen_utils.h index 8a76729dbc7a..5ddff5fc2164 100644 --- a/src/contrib/msc/framework/torch/codegen_utils.h +++ b/src/contrib/msc/framework/torch/codegen_utils.h @@ -42,7 +42,7 @@ class TorchCodeGenHelper : public BaseCodeGenHelper { const ffi::String IdxOutputBase(const MSCJoint& node, const ffi::String& prefix = "", int idx = 0, const ffi::String& suffix = "", bool mark_exit = false) final { if ((node->optype == "max" || node->optype == "min") && node->OutputAt(0)->Ndim() > 0) { - ICHECK(idx == 0) << "max and min op only support 1 outputs, get " << node; + TVM_FFI_ICHECK(idx == 0) << "max and min op only support 1 outputs, get " << node; return IdxNodeBase(node, prefix, suffix) + ".values"; } return BaseCodeGenHelper::IdxOutputBase(node, prefix, idx, suffix, mark_exit); diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc index 8f649469855e..7641f4c443f5 100644 --- a/src/contrib/msc/framework/torch/torch_opcode.cc +++ b/src/contrib/msc/framework/torch/torch_opcode.cc @@ -52,10 +52,11 @@ void TorchOpCode::CodeGenForward() { stack_.op_call().op_inputs_arg(false); } const StrictListDoc TorchOpCode::GetPadding(const ffi::String& key) { std::vector padding, src_padding; - ICHECK(node()->GetAttr(key, &src_padding)); + TVM_FFI_ICHECK(node()->GetAttr(key, &src_padding)); if (node()->optype == "nn.conv1d" || node()->optype == "msc.conv1d_bias") { if (src_padding.size() == 2) { - ICHECK(src_padding[0] == src_padding[1]) << "Only accept symmetric padding, get " << node(); + TVM_FFI_ICHECK(src_padding[0] == src_padding[1]) + << "Only accept symmetric padding, get " << node(); padding.push_back(src_padding[0]); } else { LOG_FATAL << "nn.conv1d with unexpected padding " << node(); @@ -63,7 +64,7 @@ const StrictListDoc TorchOpCode::GetPadding(const ffi::String& key) { } else if (node()->optype == "nn.conv2d" || node()->optype == "msc.conv2d_bias" || node()->optype == "nn.avg_pool2d" || node()->optype == "nn.max_pool2d") { if (src_padding.size() == 4) { - ICHECK(src_padding[0] == src_padding[2] && src_padding[1] == src_padding[3]) + TVM_FFI_ICHECK(src_padding[0] == src_padding[2] && src_padding[1] == src_padding[3]) << "Only accept symmetric padding, get " << node(); padding.push_back(src_padding[0]); padding.push_back(src_padding[1]); @@ -161,7 +162,7 @@ class TorchBatchNormCodeGen : public TorchOpCode { protected: void CodeGenInit() final { - ICHECK(node()->GetTypeAttr("center") && node()->GetTypeAttr("scale")) + TVM_FFI_ICHECK(node()->GetTypeAttr("center") && node()->GetTypeAttr("scale")) << "Only support center and scale batchnorm, get " << node(); const auto& gamma = node()->WeightAt("gamma"); stack_.op_call().call_arg(gamma->DimAt(0), "num_features").op_arg("epsilon", "eps"); @@ -381,7 +382,7 @@ class TorchGroupNormCodeGen : public TorchOpCode { protected: void CodeGenInit() final { - ICHECK(node()->GetTypeAttr("center") && node()->GetTypeAttr("scale")) + TVM_FFI_ICHECK(node()->GetTypeAttr("center") && node()->GetTypeAttr("scale")) << "Only support center and scale batchnorm, get " << node(); int channel_axis = node()->GetTypeAttr("channel_axis"); stack_.op_call() @@ -396,7 +397,7 @@ class TorchLayerNormCodeGen : public TorchOpCode { protected: void CodeGenInit() final { - ICHECK(node()->GetTypeAttr("center") && node()->GetTypeAttr("scale")) + TVM_FFI_ICHECK(node()->GetTypeAttr("center") && node()->GetTypeAttr("scale")) << "Only support center and scale batchnorm, get " << node(); const auto& axes = CommonUtils::GetIndices(node()->GetTypeArrayAttr("axes"), node()->InputAt(0)->Ndim()); @@ -568,7 +569,7 @@ class TorchResize2dCodeGen : public TorchOpCode { if (method == "nearest_neighbor") { v_method = "nearest"; } else { - LOG(FATAL) << "Unexpected resize2d method " << method; + TVM_FFI_THROW(InternalError) << "Unexpected resize2d method " << method; } stack_.op_call().op_input_arg().op_list_arg("size").call_arg(DocUtils::ToStr(v_method), "mode"); diff --git a/src/contrib/msc/framework/tvm/codegen.cc b/src/contrib/msc/framework/tvm/codegen.cc index 29445ed7ccc3..ab571e51cc93 100644 --- a/src/contrib/msc/framework/tvm/codegen.cc +++ b/src/contrib/msc/framework/tvm/codegen.cc @@ -202,7 +202,7 @@ const ffi::String RelaxCodeGen::DescribePrim(const MSCPrim& prim) { const ffi::Array RelaxCodeGen::GetOpCodes(const MSCJoint& node) { const auto& ops_map = GetRelaxOpCodes(); auto it = ops_map->find(GetOpType(node)); - ICHECK(it != ops_map->end()) << "Unsupported relax op(" << node->optype << "): " << node; + TVM_FFI_ICHECK(it != ops_map->end()) << "Unsupported relax op(" << node->optype << "): " << node; it->second->Config(node, config(), prims()); try { return it->second->GetDocs(); diff --git a/src/contrib/msc/framework/tvm/relax_opcode.cc b/src/contrib/msc/framework/tvm/relax_opcode.cc index da2cdfba5914..846b09da8329 100644 --- a/src/contrib/msc/framework/tvm/relax_opcode.cc +++ b/src/contrib/msc/framework/tvm/relax_opcode.cc @@ -282,7 +282,7 @@ class RelaxConvCodeGen : public RelaxOpCode { if (use_bias_) { std::string out_layout_str; if (!node()->GetAttr("out_layout", &out_layout_str)) { - ICHECK(node()->GetAttr("data_layout", &out_layout_str)) + TVM_FFI_ICHECK(node()->GetAttr("data_layout", &out_layout_str)) << "out_layout or data_layout should be given, get " << node(); } const auto& out_layout = tir::Layout(out_layout_str); @@ -482,7 +482,8 @@ class RelaxPadCodeGen : public RelaxOpCode { void CodeGenBuild() final { ffi::Array pad_width; const auto& attr_pad_width = node()->GetTypeArrayAttr("pad_width"); - ICHECK(attr_pad_width.size() % 2 == 0) << "pad_width should be multiple of 2, get " << node(); + TVM_FFI_ICHECK(attr_pad_width.size() % 2 == 0) + << "pad_width should be multiple of 2, get " << node(); for (size_t i = 0; i < attr_pad_width.size(); i += 2) { const ffi::String& cur_pad = "[" + std::to_string(attr_pad_width[i]) + ", " + std::to_string(attr_pad_width[i + 1]) + "]"; diff --git a/src/contrib/msc/plugin/base_codegen.h b/src/contrib/msc/plugin/base_codegen.h index 45d8002438c1..f8d63360c76a 100644 --- a/src/contrib/msc/plugin/base_codegen.h +++ b/src/contrib/msc/plugin/base_codegen.h @@ -480,7 +480,7 @@ class BasePluginCodeGen { .func_start() .assign("info", "{}"); for (const auto& name : ListPluginNames()) { - ICHECK(this->config()->ops_info.count(name)) << "Can not find op info for " << name; + TVM_FFI_ICHECK(this->config()->ops_info.count(name)) << "Can not find op info for " << name; const auto& info = this->config()->ops_info[name]; this->stack_.assign(DocUtils::ToIndex("info", DocUtils::ToStr(name)), info); } diff --git a/src/contrib/msc/plugin/tensorrt_codegen.cc b/src/contrib/msc/plugin/tensorrt_codegen.cc index 890b9a6df7b3..c7fc27ea5bd1 100644 --- a/src/contrib/msc/plugin/tensorrt_codegen.cc +++ b/src/contrib/msc/plugin/tensorrt_codegen.cc @@ -818,7 +818,8 @@ void TensorRTPluginCodeGen::CodegenBufferInfer(const Plugin& plugin) { } void TensorRTPluginCodeGen::CodegenEnqueue(const Plugin& plugin, bool dynamic) { - ICHECK(plugin->externs.count("cuda_compute")) << "cuda_compute is needed fo TensorRT plugin"; + TVM_FFI_ICHECK(plugin->externs.count("cuda_compute")) + << "cuda_compute is needed fo TensorRT plugin"; auto prepare_tensor = [this, &dynamic](const PluginTensor& tensor, const ffi::Map& dtypes, size_t idx, const ffi::String& collect) { diff --git a/src/contrib/msc/plugin/torch_codegen.cc b/src/contrib/msc/plugin/torch_codegen.cc index d5a2b5353de4..f43fd1c1a6b3 100644 --- a/src/contrib/msc/plugin/torch_codegen.cc +++ b/src/contrib/msc/plugin/torch_codegen.cc @@ -147,7 +147,7 @@ void TorchPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { .inplace_end() .for_end(); // malloc outputs and buffers - ICHECK(plugin->externs.count("infer_output")) << "Can not find extern shape"; + TVM_FFI_ICHECK(plugin->externs.count("infer_output")) << "Can not find extern shape"; CodeGenMalloc(plugin, plugin->outputs, "output"); if (plugin->externs.count("infer_buffer")) { CodeGenMalloc(plugin, plugin->buffers, "buffer"); diff --git a/src/contrib/msc/plugin/tvm_codegen.cc b/src/contrib/msc/plugin/tvm_codegen.cc index 44f690334b95..f33d45e16eed 100644 --- a/src/contrib/msc/plugin/tvm_codegen.cc +++ b/src/contrib/msc/plugin/tvm_codegen.cc @@ -226,7 +226,8 @@ void TVMPluginCodeGen::CodeGenOpDefine(const Plugin& plugin) { } void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) { - ICHECK(!plugin->externs.count("infer_buffer")) << "infer_buffer is not supported for tvm runtime"; + TVM_FFI_ICHECK(!plugin->externs.count("infer_buffer")) + << "infer_buffer is not supported for tvm runtime"; const auto& attr_name = MetaAttrCls(plugin); const auto& func_name = ComputeName(plugin); ffi::String device_cond = ""; @@ -380,7 +381,7 @@ void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const ffi::String& d const ffi::String& t_name = prepare_tensor(plugin->outputs[i], tensor_dtypes, i, "output"); compute_args.push_back(t_name); } - ICHECK(plugin->buffers.size() == 0) << "Plugin with buffers is not supported in tvm"; + TVM_FFI_ICHECK(plugin->buffers.size() == 0) << "Plugin with buffers is not supported in tvm"; compute_args.push_back("meta_attr"); if (device == "cuda") { // TODO(tvm-team): update to support get stream from device id diff --git a/src/ir/apply_pass_to_function.cc b/src/ir/apply_pass_to_function.cc index 3dd7c6a5ff8f..a3173e990ccc 100644 --- a/src/ir/apply_pass_to_function.cc +++ b/src/ir/apply_pass_to_function.cc @@ -95,7 +95,7 @@ Pass ApplyPassToFunction(Pass pass, ffi::String func_name_regex, } if (error_if_no_function_matches_regex) { - CHECK(at_least_one_function_matched_regex) + TVM_FFI_ICHECK(at_least_one_function_matched_regex) << "No function matched regex '" << func_name_regex << "', out of functions " << [&]() { ffi::Array function_names; for (const auto& [gvar, func] : mod->functions) { diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index 583549cfa4db..7885551eda90 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -52,6 +52,16 @@ Diagnostic::Diagnostic(DiagnosticLevel level, Span span, const std::string& mess data_ = std::move(n); } +Diagnostic::Diagnostic(DiagnosticLevel level, Span span, const std::string& message, + const std::string& error_kind) { + auto n = ffi::make_object(); + n->level = level; + n->span = span; + n->message = message; + n->error_kind = error_kind; + data_ = std::move(n); +} + DiagnosticBuilder Diagnostic::Bug(Span span) { return DiagnosticBuilder(DiagnosticLevel::kBug, span); } @@ -100,8 +110,24 @@ DiagnosticBuilder Diagnostic::Error(const Object* loc) { DiagnosticBuilder Diagnostic::Note(const Object* loc) { return Note(ffi::GetRef(loc)); } +DiagnosticBuilder Diagnostic::Warning(const Object* loc) { + return Warning(ffi::GetRef(loc)); +} + DiagnosticBuilder Diagnostic::Help(const Object* loc) { return Help(ffi::GetRef(loc)); } +DiagnosticBuilder Diagnostic::Error(std::string error_kind, Span span) { + return DiagnosticBuilder(DiagnosticLevel::kError, span).WithErrorKind(std::move(error_kind)); +} + +DiagnosticBuilder Diagnostic::Error(std::string error_kind, ObjectRef loc) { + return DiagnosticBuilder(DiagnosticLevel::kError, loc).WithErrorKind(std::move(error_kind)); +} + +DiagnosticBuilder Diagnostic::Error(std::string error_kind, const Object* loc) { + return Error(std::move(error_kind), ffi::GetRef(loc)); +} + /* Diagnostic Renderer */ void DiagnosticRenderer::Render(const DiagnosticContext& ctx) { (*this)->renderer(ctx); } @@ -138,8 +164,8 @@ void DiagnosticContext::Render() { if (errs) { (*this)->renderer = DiagnosticRenderer([](DiagnosticContext) {}); // (*this)->diagnostics.clear(); - LOG(FATAL) << "DiagnosticError: one or more error diagnostics were " - << "emitted, please check diagnostic render for output."; + TVM_FFI_THROW(DiagnosticError) << "one or more error diagnostics were " + << "emitted, please check diagnostic render for output."; } } @@ -151,7 +177,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { } DiagnosticContext::DiagnosticContext(const IRModule& module, const DiagnosticRenderer& renderer) { - CHECK(renderer.defined()) << "can not initialize a diagnostic renderer with a null function"; + TVM_FFI_ICHECK(renderer.defined()) + << "can not initialize a diagnostic renderer with a null function"; auto n = ffi::make_object(); n->module = module; n->renderer = renderer; @@ -197,7 +224,7 @@ DiagnosticRenderer GetRenderer() { pf = tvm::ffi::TypedFunction(*override_pf); } else { auto default_pf = tvm::ffi::Function::GetGlobal(DEFAULT_RENDERER); - ICHECK(default_pf.has_value()) + TVM_FFI_ICHECK(default_pf.has_value()) << "Can not find registered function for " << DEFAULT_RENDERER << "." << std::endl << "Either this is an internal error or the default function was overloaded incorrectly."; pf = tvm::ffi::TypedFunction(*default_pf); @@ -267,13 +294,13 @@ void ReportAt(const DiagnosticContext& context, std::ostream& out, const Span& s return; } - ICHECK(context->module->source_map.defined()); + TVM_FFI_ICHECK(context->module->source_map.defined()); auto it = context->module->source_map->source_map.find(span->source_name); // If the source name is not in the current source map, sources were not annotated. if (it == context->module->source_map->source_map.end()) { - LOG(FATAL) << "The source maps are not populated for this module. " - << "Error: " << diagnostic->message; + TVM_FFI_THROW(InternalError) << "The source maps are not populated for this module. " + << "Error: " << diagnostic->message; } auto source = (*it).second; diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index 5a6e2c662b61..774d81fb32cf 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -41,7 +41,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ObjectPtr CreateEnvNode(const std::string& name) { auto f = tvm::ffi::Function::GetGlobal(name); - ICHECK(f.has_value()) << "Cannot find global function \'" << name << '\''; + TVM_FFI_ICHECK(f.has_value()) << "Cannot find global function \'" << name << '\''; ObjectPtr n = ffi::make_object(); n->func = *f; n->name = name; @@ -57,7 +57,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_packed("ir.EnvFuncCall", [](ffi::PackedArgs args, ffi::Any* rv) { EnvFunc env = args[0].cast(); - ICHECK_GE(args.size(), 1); + TVM_FFI_ICHECK_GE(args.size(), 1); env->func.CallPacked(args.Slice(1), rv); }) .def("ir.EnvFuncGetFunction", [](const EnvFunc& n) { return n->func; }); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index b856854a5d8f..cd1ac04abb47 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -51,26 +51,25 @@ PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) PrimExpr PrimExpr::ConvertFallbackValue(ffi::String value) { return tir::StringImm(value); } IntImm::IntImm(DataType dtype, int64_t value, Span span) { - ICHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar, but " << dtype - << " was supplied."; - ICHECK(dtype.is_int() || dtype.is_uint() || dtype.is_bool()) - << "ValueError: IntImm supports only int or uint or bool type, but " << dtype - << " was supplied."; + TVM_FFI_CHECK(dtype.is_scalar(), ValueError) + << "IntImm can only take scalar, but " << dtype << " was supplied."; + TVM_FFI_CHECK(dtype.is_int() || dtype.is_uint() || dtype.is_bool(), ValueError) + << "IntImm supports only int or uint or bool type, but " << dtype << " was supplied."; if (dtype.is_uint()) { - ICHECK_GE(value, 0U) << "ValueError: Literal value " << value - << " is negative for unsigned integer type " << dtype; + TVM_FFI_CHECK_GE(value, 0U, ValueError) + << "Literal value " << value << " is negative for unsigned integer type " << dtype; if (dtype.bits() < 64) { - ICHECK_LT(value, 1LL << dtype.bits()) - << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; + TVM_FFI_CHECK_LT(value, 1LL << dtype.bits(), ValueError) + << "Literal value " << value << " exceeds maximum of " << dtype; } } else if (dtype.bits() == 1 || dtype.is_bool()) { // int(1) - ICHECK(value == 0 || value == 1) << "ValueError: " << value << " exceeds range of " << dtype; + TVM_FFI_CHECK(value == 0 || value == 1, ValueError) << value << " exceeds range of " << dtype; } else if (dtype.bits() < 64) { - ICHECK_GE(value, -(1LL << (dtype.bits() - 1))) - << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; - ICHECK_LT(value, 1LL << (dtype.bits() - 1)) - << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; + TVM_FFI_CHECK_GE(value, -(1LL << (dtype.bits() - 1)), ValueError) + << "Literal value " << value << " exceeds minimum of " << dtype; + TVM_FFI_CHECK_LT(value, 1LL << (dtype.bits() - 1), ValueError) + << "Literal value " << value << " exceeds maximum of " << dtype; } ObjectPtr node = ffi::make_object(); node->dtype = dtype; @@ -87,29 +86,30 @@ TVM_FFI_STATIC_INIT_BLOCK() { } FloatImm::FloatImm(DataType dtype, double value, Span span) { - ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar."; + TVM_FFI_CHECK_EQ(dtype.lanes(), 1, ValueError) << "FloatImm can only take scalar."; - ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() || - dtype.is_float4() || dtype.code() >= DataType::kCustomBegin) - << "ValueError: FloatImm supports only float, but " << dtype << " was supplied."; + TVM_FFI_CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() || + dtype.is_float4() || dtype.code() >= DataType::kCustomBegin, + ValueError) + << "FloatImm supports only float, but " << dtype << " was supplied."; // check range for float32 and float16 since they have specified range. if (!std::isinf(value) && !std::isnan(value)) { if (dtype.bits() == 32) { - ICHECK_GE(value, std::numeric_limits::lowest()) - << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; - ICHECK_LE(value, std::numeric_limits::max()) - << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; + TVM_FFI_CHECK_GE(value, std::numeric_limits::lowest(), ValueError) + << "Literal value " << value << " exceeds minimum of " << dtype; + TVM_FFI_CHECK_LE(value, std::numeric_limits::max(), ValueError) + << "Literal value " << value << " exceeds maximum of " << dtype; } else if (dtype.is_float16()) { - ICHECK_GE(value, -support::kMaxFloat16) - << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; - ICHECK_LE(value, support::kMaxFloat16) - << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; + TVM_FFI_CHECK_GE(value, -support::kMaxFloat16, ValueError) + << "Literal value " << value << " exceeds minimum of " << dtype; + TVM_FFI_CHECK_LE(value, support::kMaxFloat16, ValueError) + << "Literal value " << value << " exceeds maximum of " << dtype; } else if (dtype.is_bfloat16()) { - ICHECK_GE(value, -support::kMaxBFloat16) - << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; - ICHECK_LE(value, support::kMaxBFloat16) - << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; + TVM_FFI_CHECK_GE(value, -support::kMaxBFloat16, ValueError) + << "Literal value " << value << " exceeds minimum of " << dtype; + TVM_FFI_CHECK_LE(value, support::kMaxBFloat16, ValueError) + << "Literal value " << value << " exceeds maximum of " << dtype; } else if (dtype.is_float8_e3m4() || dtype.is_float8_e4m3() || dtype.is_float8_e4m3b11fnuz() || dtype.is_float8_e4m3fn() || dtype.is_float8_e4m3fnuz() || dtype.is_float8_e5m2() || dtype.is_float8_e5m2fnuz() || dtype.is_float8_e8m0fnu()) { @@ -146,33 +146,33 @@ FloatImm::FloatImm(DataType dtype, double value, Span span) { nonneg = true; break; default: - LOG(FATAL) << "Unhandled float8 type: " << dtype; + TVM_FFI_THROW(InternalError) << "Unhandled float8 type: " << dtype; } if (nonneg) { - ICHECK_GE(value, 0) << "ValueError: Literal value " << value << " below zero for unsigned " - << dtype; + TVM_FFI_CHECK_GE(value, 0, ValueError) + << "Literal value " << value << " below zero for unsigned " << dtype; } else { - ICHECK_GE(value, -bound) << "ValueError: Literal value " << value << " below minimum of " - << dtype; + TVM_FFI_CHECK_GE(value, -bound, ValueError) + << "Literal value " << value << " below minimum of " << dtype; } - ICHECK_LE(value, bound) << "ValueError: Literal value " << value << " exceeds maximum of " - << dtype; + TVM_FFI_CHECK_LE(value, bound, ValueError) + << "Literal value " << value << " exceeds maximum of " << dtype; } else if (dtype.is_float6_e2m3fn() || dtype.is_float6_e3m2fn()) { double bound = (dtype.code() == DataType::TypeCode::kFloat6_e2m3fn) ? support::kMaxE2M3FN : support::kMaxE3M2FN; - ICHECK_GE(value, -bound) << "ValueError: Literal value " << value << " below minimum of " - << dtype; - ICHECK_LE(value, bound) << "ValueError: Literal value " << value << " exceeds maximum of " - << dtype; + TVM_FFI_CHECK_GE(value, -bound, ValueError) + << "Literal value " << value << " below minimum of " << dtype; + TVM_FFI_CHECK_LE(value, bound, ValueError) + << "Literal value " << value << " exceeds maximum of " << dtype; } else if (dtype.is_float4_e2m1fn()) { double bound = support::kMaxE2M1FN; - ICHECK_GE(value, -bound) << "ValueError: Literal value " << value << " below minimum of " - << dtype; - ICHECK_LE(value, bound) << "ValueError: Literal value " << value << " exceeds maximum of " - << dtype; + TVM_FFI_CHECK_GE(value, -bound, ValueError) + << "Literal value " << value << " below minimum of " << dtype; + TVM_FFI_CHECK_LE(value, bound, ValueError) + << "Literal value " << value << " exceeds maximum of " << dtype; } } ObjectPtr node = ffi::make_object(); diff --git a/src/ir/function.cc b/src/ir/function.cc index de14d57b3ef8..8a5da7dbefa9 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -45,7 +45,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (func->IsInstance()) { return WithAttr(Downcast(std::move(func)), key, value); } else { - LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + TVM_FFI_THROW(InternalError) + << "Do not support function type " << func->GetTypeKey(); } }) .def("ir.BaseFuncWithAttrs", @@ -63,7 +64,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (func->IsInstance()) { return WithAttrs(Downcast(std::move(func)), attr_map); } - LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Do not support function type " << func->GetTypeKey(); TVM_FFI_UNREACHABLE(); }) .def("ir.BaseFuncWithoutAttr", @@ -74,7 +75,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (func->IsInstance()) { return WithoutAttr(Downcast(std::move(func)), key); } else { - LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + TVM_FFI_THROW(InternalError) + << "Do not support function type " << func->GetTypeKey(); TVM_FFI_UNREACHABLE(); } }); diff --git a/src/ir/global_var_supply.cc b/src/ir/global_var_supply.cc index 115eba152948..700c3ef84038 100644 --- a/src/ir/global_var_supply.cc +++ b/src/ir/global_var_supply.cc @@ -62,7 +62,7 @@ GlobalVarSupply::GlobalVarSupply(const IRModule module) void GlobalVarSupplyNode::ReserveGlobalVar(const GlobalVar& var, bool allow_conflict) { name_supply_->ReserveName(var->name_hint, false); if (!allow_conflict) { - ICHECK(name_to_var_map_.count(var->name_hint) == 0) + TVM_FFI_ICHECK(name_to_var_map_.count(var->name_hint) == 0) << "GlobalVar " << var << " conflicts by name in this supply."; } name_to_var_map_[var->name_hint] = var; @@ -87,7 +87,7 @@ GlobalVar GlobalVarSupplyNode::UniqueGlobalFor(const ffi::String& name, bool add GlobalVar GlobalVarSupplyNode::FreshGlobal(ffi::String name, bool add_prefix) { ffi::String final_name = name_supply_->FreshName(name, add_prefix); - ICHECK(name_to_var_map_.find(final_name) == name_to_var_map_.end()) + TVM_FFI_ICHECK(name_to_var_map_.find(final_name) == name_to_var_map_.end()) << "GlobalVar already exists for name " << final_name; GlobalVar var = GlobalVar(final_name); name_to_var_map_.emplace(final_name, var); diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc index b969470caa89..8d1dd2ecf54f 100644 --- a/src/ir/instrument.cc +++ b/src/ir/instrument.cc @@ -248,7 +248,7 @@ void PassProfile::EnterPass(ffi::String name) { void PassProfile::ExitPass() { PassProfile* cur = PassProfile::Current(); - ICHECK_NE(cur->name, "root") << "mismatched enter/exit for pass profiling"; + TVM_FFI_ICHECK_NE(cur->name, "root") << "mismatched enter/exit for pass profiling"; cur->end = PassProfile::Clock::now(); cur->duration = std::chrono::duration_cast(cur->end - cur->start); PassProfileThreadLocalStoreGet()->profile_stack.pop(); @@ -265,7 +265,8 @@ PassProfile* PassProfile::Current() { ffi::String RenderPassProfiles() { PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStoreGet(); - CHECK(entry->profile_stack.empty()) << "cannot print pass profile while still in a pass!"; + TVM_FFI_ICHECK(entry->profile_stack.empty()) + << "cannot print pass profile while still in a pass!"; if (entry->root.children.empty()) { LOG(WARNING) << "no passes have been profiled, did you enable pass profiling?"; diff --git a/src/ir/module.cc b/src/ir/module.cc index e7f251103814..04e3026b0f11 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -49,7 +49,7 @@ IRModule::IRModule(tvm::ffi::Map functions, SourceMap sourc for (const auto& kv : n->functions) { // set global var map - ICHECK(n->global_var_map_.count(kv.first->name_hint) == 0) + TVM_FFI_ICHECK(n->global_var_map_.count(kv.first->name_hint) == 0) << "Duplicate global function name " << kv.first->name_hint; n->global_var_map_.Set(kv.first->name_hint, kv.first); } @@ -118,7 +118,7 @@ GlobalVar IRModuleNode::GetGlobalVar(const ffi::String& name) const { auto it = global_var_map_.find(name); if (it == global_var_map_.end()) { std::ostringstream msg; - msg << "ValueError: Cannot find global var \"" << name << "\" in the Module\n" + msg << "Cannot find global var \"" << name << "\" in the Module\n" << "candidates are: ["; int counter = 0; for (auto kv : global_var_map_) { @@ -128,7 +128,7 @@ GlobalVar IRModuleNode::GetGlobalVar(const ffi::String& name) const { msg << "\"" << kv.first << "\""; } msg << "]"; - LOG(FATAL) << msg.str(); + TVM_FFI_THROW(ValueError) << msg.str(); } return (*it).second; } @@ -154,9 +154,10 @@ void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) { auto it = global_var_map_.find(var->name_hint); if (it != global_var_map_.end()) { - ICHECK_EQ((*it).second, var); + TVM_FFI_ICHECK_EQ((*it).second, var); } else { - ICHECK(global_var_map_.count(var->name_hint) == 0) << "Duplicate global function name " << var; + TVM_FFI_ICHECK(global_var_map_.count(var->name_hint) == 0) + << "Duplicate global function name " << var; } global_var_map_.Set(var->name_hint, var); @@ -179,7 +180,7 @@ void IRModuleNode::Remove(const GlobalVar& var) { BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { auto it = functions.find(var); - ICHECK(it != functions.end()) << "There is no definition of " << var; + TVM_FFI_ICHECK(it != functions.end()) << "There is no definition of " << var; return (*it).second; } @@ -240,8 +241,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (attrs.as()) { return tvm::DictAttrs(Downcast>(attrs)); } else { - LOG(FATAL) << "Expected attrs argument to be either DictAttrs or " - "ffi::Map"; + TVM_FFI_THROW(InternalError) + << "Expected attrs argument to be either DictAttrs or " + "ffi::Map"; } }(); @@ -255,7 +257,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { }) .def("ir.Module_Add", [](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule { - ICHECK(val->IsInstance()); + TVM_FFI_ICHECK(val->IsInstance()); mod->Add(var, Downcast(val), update); return mod; }) @@ -267,8 +269,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (auto opt = var.as()) { return mod->GetGlobalVar(opt.value()); } else { - LOG(FATAL) << "InternalError: " - << "Variant didn't contain any of the allowed types"; + TVM_FFI_THROW(InternalError) << "Variant didn't contain any of the allowed types"; } }(); mod->Remove(gvar); @@ -281,8 +282,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (auto opt = var.as()) { return mod->global_var_map_.count(opt.value()); } else { - LOG(FATAL) << "InternalError: " - << "Variant didn't contain any of the allowed types"; + TVM_FFI_THROW(InternalError) << "Variant didn't contain any of the allowed types"; } }) .def_method("ir.Module_GetGlobalVar", &IRModuleNode::GetGlobalVar) diff --git a/src/ir/op.cc b/src/ir/op.cc index 514b45c65ad0..825437c216d7 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -46,7 +46,7 @@ using OpRegistry = AttrRegistry; // find operator by name const Op& Op::Get(const ffi::String& name) { const OpRegEntry* reg = OpRegistry::Global()->Get(name); - ICHECK(reg != nullptr) << "AttributeError: Operator " << name << " is not registered"; + TVM_FFI_CHECK(reg != nullptr, AttributeError) << "Operator " << name << " is not registered"; return reg->op(); } @@ -109,8 +109,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("ir.RegisterOp", [](ffi::String op_name, ffi::String descr) { const OpRegEntry* reg = OpRegistry::Global()->Get(op_name); - ICHECK(reg == nullptr) - << "AttributeError: Operator " << op_name << " is registered before"; + TVM_FFI_CHECK(reg == nullptr, AttributeError) + << "Operator " << op_name << " is registered before"; auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); op.describe(descr); }) @@ -141,7 +141,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (attr_key == "num_inputs" && plevel > 128) { reg.set_num_inputs(value.cast()); } else if (attr_key == "attrs_type_key" && plevel > 128) { - LOG(FATAL) << "attrs type key no longer supported"; + TVM_FFI_THROW(InternalError) << "attrs type key no longer supported"; } else { reg.set_attr(attr_key, value, plevel); } diff --git a/src/ir/replace_global_vars.cc b/src/ir/replace_global_vars.cc index 98b5b74c42cd..2a3517b4d815 100644 --- a/src/ir/replace_global_vars.cc +++ b/src/ir/replace_global_vars.cc @@ -80,7 +80,7 @@ IRModule ModuleReplaceGlobalVars( } else if (auto str = before.as()) { gvar_before = mod->GetGlobalVar(str.value()); } else { - LOG(FATAL) + TVM_FFI_THROW(InternalError) << "ffi::Variant must contain either ffi::String or GlobalVar"; } @@ -91,7 +91,7 @@ IRModule ModuleReplaceGlobalVars( gvar_after = gvar_before; gvar_after.CopyOnWrite()->name_hint = str.value(); } else { - LOG(FATAL) + TVM_FFI_THROW(InternalError) << "ffi::Variant must contain either ffi::String or GlobalVar"; } diff --git a/src/ir/source_map.cc b/src/ir/source_map.cc index 521b02db44b5..7b94890623fd 100644 --- a/src/ir/source_map.cc +++ b/src/ir/source_map.cc @@ -90,9 +90,9 @@ Span::Span(SourceName source_name, int line, int end_line, int column, int end_c } Span Span::Merge(const Span& other) const { - ICHECK(this->defined() && other.defined()) << "Span::Merge: both spans must be defined"; + TVM_FFI_ICHECK(this->defined() && other.defined()) << "Span::Merge: both spans must be defined"; - ICHECK((*this)->source_name == other->source_name); + TVM_FFI_ICHECK((*this)->source_name == other->source_name); return Span((*this)->source_name, std::min((*this)->line, other->line), std::max((*this)->end_line, other->end_line), std::min((*this)->column, other->column), @@ -203,7 +203,7 @@ Source::Source(SourceName src_name, std::string source) { tvm::ffi::String Source::GetLine(int line) { VLOG(1) << "Source::GetLine: line=" << line; - ICHECK(line - 1 < static_cast((*this)->line_map.size())) + TVM_FFI_ICHECK(line - 1 < static_cast((*this)->line_map.size())) << "requested line: " << line << "at index: " << (line - 1) << "line_map size: " << (*this)->line_map.size() << "source: " << (*this)->source; diff --git a/src/ir/transform.cc b/src/ir/transform.cc index c64770654471..148918be8eee 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -67,8 +67,8 @@ void PassContext::EnterWithScope() { void PassContext::ExitWithScope() { PassContextThreadLocalEntry* entry = PassContextThreadLocalStoreGet(); - ICHECK(!entry->context_stack.empty()); - ICHECK(entry->context_stack.top().same_as(*this)); + TVM_FFI_ICHECK(!entry->context_stack.empty()); + TVM_FFI_ICHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); InstrumentExitPassContext(); @@ -107,7 +107,7 @@ class PassConfigManager { public: void Register(std::string key, ffi::String value_type_str, std::function legalization) { - ICHECK_EQ(key2vtype_.count(key), 0U); + TVM_FFI_ICHECK_EQ(key2vtype_.count(key), 0U); ValueTypeInfo info; info.type_str = value_type_str; info.legalization = legalization; @@ -121,22 +121,21 @@ class PassConfigManager { auto it = key2vtype_.find(key); if (it == key2vtype_.end()) { std::ostringstream os; - os << "AttributeError: Invalid config option \'" << key << "\' candidates are:"; + os << "Invalid config option \'" << key << "\' candidates are:"; int counter = 0; for (const auto& [key, value] : key2vtype_) { os << ' '; if (counter++ != 0) os << ','; os << key; } - LOG(FATAL) << os.str(); + TVM_FFI_THROW(AttributeError) << os.str(); } const auto& info = it->second; - ICHECK(value != nullptr) << "AttributeError: " << key << " is None"; + TVM_FFI_CHECK(value != nullptr, AttributeError) << key << " is None"; - ICHECK(info.legalization) << "AttributeError: " - << "Config option \'" << key - << "\' was defined without a legalization function."; + TVM_FFI_CHECK(info.legalization, AttributeError) + << "Config option \'" << key << "\' was defined without a legalization function."; auto legalized = info.legalization(value); if (!legalized.same_as(value)) { update.emplace_back(key, legalized); @@ -294,7 +293,7 @@ IRModule Pass::operator()(IRModule mod) const { IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassNode* node = operator->(); - ICHECK(node != nullptr); + TVM_FFI_ICHECK(node != nullptr); const PassInfo& pass_info = node->Info(); if (!pass_ctx.InstrumentBeforePass(mod, pass_info)) { DLOG(INFO) << "Skipping pass : " << pass_info->name @@ -404,20 +403,20 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c pass_ctx->diag_ctx = previous; } - ICHECK(pass_ctx->diag_ctx) + TVM_FFI_ICHECK(pass_ctx->diag_ctx) << "The diagnostic context was set at the top of this block this is a bug."; const PassInfo& pass_info = Info(); - ICHECK(mod.defined()) << "The input module must be set."; + TVM_FFI_ICHECK(mod.defined()) << "The input module must be set."; VLOG_CONTEXT << pass_info->name; VLOG(0) << "Executing module pass with opt level: " << pass_info->opt_level; mod = pass_func(std::move(mod), pass_ctx); - ICHECK(mod.defined()) << "The return value of a module pass must be set."; + TVM_FFI_ICHECK(mod.defined()) << "The return value of a module pass must be set."; - ICHECK(pass_ctx->diag_ctx) + TVM_FFI_ICHECK(pass_ctx->diag_ctx) << "The diagnostic context was set at the top of this block this is a bug."; pass_ctx->diag_ctx.value().Render(); @@ -450,8 +449,8 @@ void SequentialNode::ResolveDependency(const IRModule& mod) { // 1. Consider the required passes for each pass. // 2. Only resolve the enabled passes. // 3. Build a dependency graph. Probably we need to update the pass list. - LOG(FATAL) << "Pass dependency has not been resolved yet." - << "\n"; + TVM_FFI_THROW(InternalError) << "Pass dependency has not been resolved yet." + << "\n"; } Pass GetPass(const ffi::String& pass_name) { @@ -461,7 +460,7 @@ Pass GetPass(const ffi::String& pass_name) { } else { f = tvm::ffi::Function::GetGlobal("transform." + pass_name); } - ICHECK(f.has_value()) << "Cannot use " << pass_name << " to create the pass"; + TVM_FFI_ICHECK(f.has_value()) << "Cannot use " << pass_name << " to create the pass"; return (*f)().cast(); } @@ -471,7 +470,7 @@ Pass GetPass(const ffi::String& pass_name) { IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const { for (const Pass& pass : passes) { VLOG(0) << "Running pass " << pass->Info()->name; - ICHECK(pass.defined()) << "Found undefined pass for optimization."; + TVM_FFI_ICHECK(pass.defined()) << "Found undefined pass for optimization."; const PassInfo& pass_info = pass->Info(); if (!pass_ctx.PassEnabled(pass_info)) { VLOG(0) << "skipping disabled pass '" << pass_info->name << "'"; diff --git a/src/node/attr_registry.h b/src/node/attr_registry.h index fee7eeb26cab..b956cb5616e9 100644 --- a/src/node/attr_registry.h +++ b/src/node/attr_registry.h @@ -105,10 +105,11 @@ class AttrRegistry { op_map->data_.resize(index + 1, std::make_pair(Any(), 0)); } std::pair& p = op_map->data_[index]; - ICHECK(p.second != plevel) << "Attribute " << attr_name << " of " << key->AttrRegistryName() - << " is already registered with same plevel=" << plevel; - ICHECK(value != nullptr) << "Registered packed_func is Null for " << attr_name - << " of operator " << key->AttrRegistryName(); + TVM_FFI_ICHECK(p.second != plevel) + << "Attribute " << attr_name << " of " << key->AttrRegistryName() + << " is already registered with same plevel=" << plevel; + TVM_FFI_ICHECK(value != nullptr) << "Registered packed_func is Null for " << attr_name + << " of operator " << key->AttrRegistryName(); if (p.second < plevel && value != nullptr) { op_map->data_[index] = std::make_pair(value, plevel); } @@ -138,7 +139,7 @@ class AttrRegistry { const AttrRegistryMapContainerMap& GetAttrMap(const ffi::String& attr_name) { auto it = attrs_.find(attr_name); if (it == attrs_.end()) { - LOG(FATAL) << "Attribute \'" << attr_name << "\' is not registered"; + TVM_FFI_THROW(InternalError) << "Attribute \'" << attr_name << "\' is not registered"; } return *it->second.get(); } diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index 36c61d78b345..09413ba007e3 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -128,10 +128,10 @@ PrinterConfig::PrinterConfig(ffi::Map config_dict) { } // Checking prefixes if they are valid Python identifiers. - CHECK(IsIdentifier(n->ir_prefix)) << "Invalid `ir_prefix`: " << n->ir_prefix; - CHECK(IsIdentifier(n->tir_prefix)) << "Invalid `tir_prefix`: " << n->tir_prefix; - CHECK(IsIdentifier(n->relax_prefix)) << "Invalid `relax_prefix`: " << n->relax_prefix; - CHECK(n->module_alias.empty() || IsIdentifier(n->module_alias)) + TVM_FFI_ICHECK(IsIdentifier(n->ir_prefix)) << "Invalid `ir_prefix`: " << n->ir_prefix; + TVM_FFI_ICHECK(IsIdentifier(n->tir_prefix)) << "Invalid `tir_prefix`: " << n->tir_prefix; + TVM_FFI_ICHECK(IsIdentifier(n->relax_prefix)) << "Invalid `relax_prefix`: " << n->relax_prefix; + TVM_FFI_ICHECK(n->module_alias.empty() || IsIdentifier(n->module_alias)) << "Invalid `module_alias`: " << n->module_alias; this->data_ = std::move(n); diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 6f7bca7cc517..f32f0756c04d 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -76,7 +76,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { support::Base64InStream b64strm(&mstrm); b64strm.InitPosition(); runtime::Tensor temp; - ICHECK(temp.Load(&b64strm)); + TVM_FFI_ICHECK(temp.Load(&b64strm)); return temp; }); } diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc index a61d548443a3..69f3dad32022 100644 --- a/src/relax/analysis/analysis.cc +++ b/src/relax/analysis/analysis.cc @@ -187,7 +187,7 @@ ffi::Optional FindImpureCall(const Expr& expr, const ffi::Optional& }; if (own_name) { - ICHECK(own_name.value().as() || own_name.value().as()) + TVM_FFI_ICHECK(own_name.value().as() || own_name.value().as()) << "Must pass a Var or GlobalVar for own_name"; } diff --git a/src/relax/analysis/graph_partitioner.cc b/src/relax/analysis/graph_partitioner.cc index d68626160fe9..0eb4eaf43792 100644 --- a/src/relax/analysis/graph_partitioner.cc +++ b/src/relax/analysis/graph_partitioner.cc @@ -64,9 +64,9 @@ DominatorTree::Node* DominatorTree::LeastCommonAncestor( } auto get_node = [&](const IndexedForwardGraph::Edge& edge) { size_t oindex = edge.node->index; - ICHECK_LT(oindex, nodes.size()); + TVM_FFI_ICHECK_LT(oindex, nodes.size()); Node* onode = nodes[oindex]; - ICHECK(onode != nullptr); + TVM_FFI_ICHECK(onode != nullptr); return onode; }; Node* parent = get_node(link->value); @@ -133,7 +133,7 @@ bool GraphPartitioner::CheckPath_(IndexedForwardGraph::Node* src, IndexedForward if (visited_.count(src)) return true; visited_.insert(src); Group* gnode = groups_[src->index]; - ICHECK(gnode != nullptr); + TVM_FFI_ICHECK(gnode != nullptr); gnode = gnode->FindRoot(); if (!fcond(gnode->pattern, src == sink)) return false; if (src == sink) return true; @@ -146,9 +146,9 @@ bool GraphPartitioner::CheckPath_(IndexedForwardGraph::Node* src, IndexedForward template bool GraphPartitioner::CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) { - ICHECK(!src->extern_ref); + TVM_FFI_ICHECK(!src->extern_ref); visited_.clear(); - ICHECK(src != sink); + TVM_FFI_ICHECK(src != sink); for (auto link = src->outputs.head; link != nullptr; link = link->next) { if (!CheckPath_(link->value.node, sink, fcond)) return false; } @@ -157,7 +157,7 @@ bool GraphPartitioner::CheckPath(IndexedForwardGraph::Node* src, IndexedForwardG OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { if (lhs > kBroadcast && rhs > kBroadcast) { - LOG(FATAL) << "Cannot merge two complex group together"; + TVM_FFI_THROW(InternalError) << "Cannot merge two complex group together"; } if (lhs > rhs) return lhs; return rhs; @@ -173,7 +173,7 @@ void GraphPartitioner::MergeFromTo(Group* child, Group* parent) { child->parent = parent; // update anchor ref and pattern if (child->anchor_ref != nullptr) { - ICHECK(parent->anchor_ref == nullptr); + TVM_FFI_ICHECK(parent->anchor_ref == nullptr); parent->anchor_ref = child->anchor_ref; parent->pattern = CombinePattern(child->pattern, parent->pattern); } @@ -189,7 +189,7 @@ void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwar if (visited_.count(src)) return; visited_.insert(src); Group* gnode = groups_[src->index]; - ICHECK(gnode != nullptr); + TVM_FFI_ICHECK(gnode != nullptr); // merge the current group to the parent if possible. MergeFromTo(gnode, target); for (auto link = src->outputs.head; link != nullptr; link = link->next) { @@ -200,7 +200,7 @@ void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwar void GraphPartitioner::CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { Group* target = groups_[sink->index]; visited_.clear(); - ICHECK(src != sink); + TVM_FFI_ICHECK(src != sink); CommitFuse_(src, sink, target); } @@ -209,7 +209,7 @@ size_t GraphPartitioner::CountNodesUptoSink_(IndexedForwardGraph::Node* src, if (src == sink || visited_.count(src)) return 0; visited_.insert(src); Group* gnode = groups_[src->index]; - ICHECK(gnode != nullptr); + TVM_FFI_ICHECK(gnode != nullptr); auto sum = gnode->num_nodes; for (auto link = src->outputs.head; link != nullptr; link = link->next) { sum += CountNodesUptoSink_(link->value.node, sink); @@ -221,7 +221,7 @@ size_t GraphPartitioner::CountFusedNodesWithNewChild(IndexedForwardGraph::Node* IndexedForwardGraph::Node* dom_parent) { Group* target = groups_[dom_parent->index]; visited_.clear(); - ICHECK(child != dom_parent); + TVM_FFI_ICHECK(child != dom_parent); return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent); } @@ -229,7 +229,7 @@ size_t GraphPartitioner::CountArgs_(IndexedForwardGraph::Node* src, const IndexedForwardGraph& graph, bool update_postpone) { std::unordered_set visited_groups; Group* gnode = groups_[src->index]; - ICHECK(gnode != nullptr); + TVM_FFI_ICHECK(gnode != nullptr); auto sum = gnode->args_num; visited_groups.insert(gnode->FindRoot()); auto calc_args_number = [this, src, &graph, &visited_groups, @@ -329,7 +329,7 @@ void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, // auto* graph_node = graph.post_dfs_order[nid]; auto* dom_node = post_dom_tree.nodes[nid]; Group* group_node = groups_[nid]; - ICHECK(group_node != nullptr); + TVM_FFI_ICHECK(group_node != nullptr); postpone_node_ = nullptr; // Check if the fusing of some inputs was postponed if (postponed_fusing_map_.count(graph_node)) { @@ -349,7 +349,7 @@ void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, // if (group_node->pattern == kOpaque) continue; // no actions needed if the current node have no dominator if (dom_node->parent == nullptr) continue; - ICHECK(!graph_node->extern_ref); + TVM_FFI_ICHECK(!graph_node->extern_ref); size_t dom_parent_gindex = dom_node->parent->gnode->index; // refuse the fusion if too many ops are going to be fused together @@ -397,7 +397,7 @@ void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, // // Path for OutEWiseFusable: conv2d // Check if the dominator relation is elemwise. if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) { - ICHECK(dom_node->parent->gnode != nullptr); + TVM_FFI_ICHECK(dom_node->parent->gnode != nullptr); // The fuse can be executed if all the intermediate ops are still broadcast. auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; }; if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { @@ -435,7 +435,7 @@ void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, // } } else { // do nothing. - ICHECK(group_node->pattern == kCommReduce); + TVM_FFI_ICHECK(group_node->pattern == kCommReduce); } } } diff --git a/src/relax/analysis/layout_transformation.cc b/src/relax/analysis/layout_transformation.cc index 7eed4cd4aa9d..38752128b87f 100644 --- a/src/relax/analysis/layout_transformation.cc +++ b/src/relax/analysis/layout_transformation.cc @@ -42,7 +42,7 @@ using namespace tir; /*! \brief Checks if a transformation is bijective affine over the given ranges */ static bool IsBijectiveAffine(const IndexMap& m, const ffi::Array& ranges) { ffi::Map input_iters; - ICHECK_EQ(m->initial_indices.size(), ranges.size()); + TVM_FFI_ICHECK_EQ(m->initial_indices.size(), ranges.size()); for (size_t i = 0; i < ranges.size(); i++) { input_iters.Set(m->initial_indices[i], ranges[i]); } @@ -113,7 +113,7 @@ class IndexAnalyzer : public ExprVisitor { */ using SpatialLayout = ffi::Array>; static SpatialLayout GetSpatialLayout(const arith::IterMapResult& iter_map_result) { - ICHECK(!iter_map_result->indices.empty()); + TVM_FFI_ICHECK(!iter_map_result->indices.empty()); SpatialLayout result; for (const arith::IterSumExpr& index : iter_map_result->indices) { IndexAnalyzer index_analyzer; @@ -159,7 +159,7 @@ static bool IsSequentialAccess(const SpatialLayout& iterators, for (const auto& i : iterators) { if (!i.defined()) continue; auto it = iter_to_block_index.find(i.value()); - ICHECK(it != iter_to_block_index.end()); + TVM_FFI_ICHECK(it != iter_to_block_index.end()); int blk_index = it->second; if (blk_index <= last_value) return false; last_value = blk_index; @@ -231,7 +231,7 @@ static ffi::Optional InferLayoutTransformation(const SpatialLayout& sr auto initial_indices_it = initial_indices.begin(); VarSet initial_indices_var_set; for (const auto& i : src_spatial_layout) { - ICHECK(i.defined()); + TVM_FFI_ICHECK(i.defined()); if (tgt_var_set.count(i.value())) { initial_indices_var_set.insert(*initial_indices_it); initial_indices_it++; @@ -245,7 +245,7 @@ static ffi::Optional InferLayoutTransformation(const SpatialLayout& sr while (final_indices_it != final_indices.end()) { // Collect all the vars used in this final index. ffi::Array used_vars = tir::UndefinedVars(*final_indices_it); - ICHECK(!used_vars.empty()) + TVM_FFI_ICHECK(!used_vars.empty()) << "IndexMap expression must always contain tir::Var nodes but found none in: " << *final_indices_it; @@ -283,7 +283,7 @@ static ffi::Optional InferLayoutTransformation(const SpatialLayout& sr // spatial layout. VarSet src_var_set; for (const auto& i : src_spatial_layout) { - ICHECK(i.defined()); + TVM_FFI_ICHECK(i.defined()); src_var_set.insert(i.value()); } @@ -325,7 +325,7 @@ class BlockAnalyzer : public StmtExprVisitor { write_transformation_(write_transformation), block_(block), buffer_transformation_cache_(transformation_cache) { - ICHECK(block_->writes.size() == 1); + TVM_FFI_ICHECK(block_->writes.size() == 1); auto write_buffer = block_->writes[0]->buffer; ComputeBlockSpatialDomain(); @@ -544,15 +544,16 @@ class BlockAnalyzer : public StmtExprVisitor { class PrimFuncAnalyzer : public StmtExprVisitor { public: explicit PrimFuncAnalyzer(const PrimFunc& func, ffi::Array write_transformations) { - ICHECK_LE(write_transformations.size(), func->params.size()) + TVM_FFI_ICHECK_LE(write_transformations.size(), func->params.size()) << "Incompatible PrimFunc and write_transformations"; size_t first_write_index = func->params.size() - write_transformations.size(); for (size_t i = 0; i < write_transformations.size(); ++i) { auto param = func->params[first_write_index + i]; ffi::Optional param_buf = func->buffer_map.Get(param); - ICHECK(param_buf.defined()); - ICHECK_EQ(param_buf.value()->shape.size(), write_transformations[i]->initial_indices.size()) + TVM_FFI_ICHECK(param_buf.defined()); + TVM_FFI_ICHECK_EQ(param_buf.value()->shape.size(), + write_transformations[i]->initial_indices.size()) << "Mismatch between output buffer shape and index map"; buffer_transformation_cache_.Set(param_buf.value(), write_transformations[i]); } @@ -595,7 +596,7 @@ class PrimFuncAnalyzer : public StmtExprVisitor { // BlockAnalyzer makes sure that it does not propose transformation for a buffer for which a // transformation has already been proposed by other blocks or by write_transformations which // are input to this analysis. - ICHECK_EQ(buffer_transformation_cache_.count(buffer), 0); + TVM_FFI_ICHECK_EQ(buffer_transformation_cache_.count(buffer), 0); buffer_transformation_cache_.Set(buffer, index_map); block_to_buffer_[block].push_back(buffer); } diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index 3952b1ce4a6e..cd951896d821 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -105,7 +105,7 @@ StructInfo StructInfoFromType(const Type& type) { // TODO(relax-team): Maybe add purity into the type as well return FuncStructInfo(params, ret, true, func_type->span); } else { - LOG(FATAL) << "Unsupported type: " << type; + TVM_FFI_THROW(InternalError) << "Unsupported type: " << type; return StructInfo(); } } @@ -222,7 +222,7 @@ class WellDefinedEraser : public StructInfoMutator, } has_undefined_ = has_undefined_ || !ret.defined(); if (ret.defined()) { - ICHECK(ret.as() || ret.as()) + TVM_FFI_ICHECK(ret.as() || ret.as()) << "Only allow Expr in StructInfo to be ShapeExpr or Var"; } return ret.value_or(ffi::GetRef(var)); @@ -240,7 +240,8 @@ class WellDefinedEraser : public StructInfoMutator, if (value->IsInstance()) { return tvm::cast(DataType::Int(64), value); } - ICHECK(value.dtype() == DataType::Int(64)) << "Can only provide i64 expressions in shape"; + TVM_FFI_ICHECK(value.dtype() == DataType::Int(64)) + << "Can only provide i64 expressions in shape"; return value; } else { return ffi::GetRef(var); @@ -933,7 +934,7 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { } auto lhs_shape = lhs.as(); auto rhs_shape = rhs.as(); - ICHECK(lhs_shape) << "lhs must have a shape"; + TVM_FFI_ICHECK(lhs_shape) << "lhs must have a shape"; if (!rhs_shape) return BaseCheckResult::kFailL2; return ShapeMatchCheck(lhs_shape->values, rhs_shape->values); } @@ -1206,7 +1207,8 @@ class TIRVarsDetector : public StructInfoVisitor { RecordTIRVar(tir_var); } } else { - LOG(FATAL) << "Invalid value for VarType enum, " << static_cast(collection_type); + TVM_FFI_THROW(InternalError) + << "Invalid value for VarType enum, " << static_cast(collection_type); } } diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index 3a3e0e6697bc..19ecb4ca68da 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -382,7 +382,7 @@ bool HasReshapePattern(const PrimFunc& func) { const SBlock& block = block_realize->block; const ffi::Array& block_iter = block->iter_vars; const ffi::Array& iter_values = block_realize->iter_values; - ICHECK_EQ(block_iter.size(), iter_values.size()); + TVM_FFI_ICHECK_EQ(block_iter.size(), iter_values.size()); int n_iter = block_iter.size(); for (int i = 0; i < n_iter; ++i) { // To detect the reshape pattern, we require each block iter to be data-parallel. @@ -431,7 +431,7 @@ bool HasReshapePattern(const PrimFunc& func) { // access (e.g., buf[ax0, ax1, ax2]). auto f_calc_flattened_idx = [&](const Buffer& buffer, const ffi::Array& indices) { - ICHECK_EQ(indices.size(), buffer->shape.size()); + TVM_FFI_ICHECK_EQ(indices.size(), buffer->shape.size()); int ndim = indices.size(); PrimExpr idx = 0; for (int i = 0; i < ndim; ++i) { @@ -495,7 +495,7 @@ bool HasReshapePattern(const PrimFunc& func) { /*check_level=*/arith::IterMapLevel::Surjective, /*analyzer=*/&this->ana_, /*simplify_trivial_iterators=*/true); - ICHECK_EQ(simplify_res.size(), 1); + TVM_FFI_ICHECK_EQ(simplify_res.size(), 1); if (simplify_res[0].same_as(fused_var)) { this->is_reshape_ = true; @@ -535,7 +535,7 @@ bool HasReshapePattern(const PrimFunc& func) { // To detect the reshape pattern, we require each For to have // either another For or a BlockRealize as body. - ICHECK(func->body->IsInstance()); + TVM_FFI_ICHECK(func->body->IsInstance()); return ReshapeDetector::Detect(src_buffer, dst_buffer, func->body); } diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index fcd628f606cf..188180384397 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -63,7 +63,7 @@ class UDChain : relax::ExprVisitor { ffi::Optional cur_user_; void VisitBinding_(const VarBindingNode* binding) override { - CHECK(!bound_values.count(binding->var)) + TVM_FFI_ICHECK(!bound_values.count(binding->var)) << "Variable " << binding->var << " was defined multiple times"; bound_values.Set(binding->var, binding->value); @@ -104,7 +104,8 @@ class UDChain : relax::ExprVisitor { } void DefineVar(const Var& var) { - CHECK(!usage_map.count(var)) << "Variable " << var << " was used before its definition"; + TVM_FFI_ICHECK(!usage_map.count(var)) + << "Variable " << var << " was used before its definition"; usage_map[var] = {}; } }; diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index f18143cdd291..2c1a42fbe843 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -105,8 +105,8 @@ class WellFormedChecker : public relax::ExprVisitor, } else if (const auto* func = obj.as()) { well_formed_checker.VisitExpr(ffi::GetRef(func)); } else { - LOG(FATAL) << "Unreachable, " - << "variant did not contain any of the allowed types"; + TVM_FFI_THROW(InternalError) << "Unreachable, " + << "variant did not contain any of the allowed types"; } return well_formed_checker.well_formed_; } @@ -250,7 +250,7 @@ class WellFormedChecker : public relax::ExprVisitor, // first populate defs in params WithMode(VisitMode::kMatchVarDef, [&]() { - ICHECK(mode_ == VisitMode::kMatchVarDef); + TVM_FFI_ICHECK(mode_ == VisitMode::kMatchVarDef); for (Var param : op->params) { relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param)); } @@ -574,7 +574,7 @@ class WellFormedChecker : public relax::ExprVisitor, void VisitStructInfo_(const FuncStructInfoNode* op) final { if (op->params.defined()) { WithMode(VisitMode::kMatchVarDef, [&]() { - ICHECK(mode_ == VisitMode::kMatchVarDef); + TVM_FFI_ICHECK(mode_ == VisitMode::kMatchVarDef); for (StructInfo param : op->params.value()) { this->VisitStructInfo(param); } diff --git a/src/relax/backend/adreno/annotate_custom_storage.cc b/src/relax/backend/adreno/annotate_custom_storage.cc index f42c5c456124..861e57aeb7d5 100644 --- a/src/relax/backend/adreno/annotate_custom_storage.cc +++ b/src/relax/backend/adreno/annotate_custom_storage.cc @@ -260,7 +260,7 @@ using tvm::tir::Buffer; static ffi::Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { auto shape = tensor_sinfo->GetShape(); - ICHECK(shape.defined()); + TVM_FFI_ICHECK(shape.defined()); return shape.value(); } @@ -505,7 +505,7 @@ class CollectProducerScopeInfo : public ExprVisitor { auto* op_ptr = call->op.as(); Op op = ffi::GetRef(op_ptr); - ICHECK(op_map_infer_struct_info_.count(op)) + TVM_FFI_ICHECK(op_map_infer_struct_info_.count(op)) << " Cannot find the FInferStructInfo attribute registered to op: " << op->name; out_sinfo = op_map_infer_struct_info_[op](ffi::GetRef(call), builder_); } @@ -547,7 +547,7 @@ class CollectProducerScopeInfo : public ExprVisitor { VDevice(target_, 0, scope[0])); } - ICHECK(out_sinfo->IsInstance()) + TVM_FFI_ICHECK(out_sinfo->IsInstance()) << "Expect output struct info of call_tir to be either TupleStructInfo or " "TensorStructInfo, but got " << out_sinfo; @@ -555,7 +555,7 @@ class CollectProducerScopeInfo : public ExprVisitor { const auto& tuple_sinfo = Downcast(out_sinfo); ffi::Array sinfo_fields; for (const auto& si : tuple_sinfo->fields) { - ICHECK(si->IsInstance()) + TVM_FFI_ICHECK(si->IsInstance()) << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " "output structinfo, but got " << si; @@ -649,7 +649,7 @@ class DefineVDevice : ExprMutator { updated_ret_sinfo = TensorStructInfo(shape, dtype, vdev_global); } } else { - ICHECK(updated_ret_sinfo->IsInstance()) + TVM_FFI_ICHECK(updated_ret_sinfo->IsInstance()) << "Expect output struct info of call_tir to be either TupleStructInfo or " "TensorStructInfo, but got " << updated_ret_sinfo; @@ -657,7 +657,7 @@ class DefineVDevice : ExprMutator { const auto& tuple_sinfo = Downcast(updated_ret_sinfo); ffi::Array sinfo_fields; for (const auto& si : tuple_sinfo->fields) { - ICHECK(si->IsInstance()) + TVM_FFI_ICHECK(si->IsInstance()) << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " "output structinfo, but got " << si; @@ -720,7 +720,7 @@ class DefineVDevice : ExprMutator { if (auto tsinfo = arg->struct_info_.as()) { if (!tsinfo->vdevice.defined()) { const VDevice& vdev = MakeGlobalVDevice(VDevice(target_, 0, scope)); - CHECK(tsinfo->shape.defined()) << "Shape not defined for a constant tensor ..!"; + TVM_FFI_ICHECK(tsinfo->shape.defined()) << "Shape not defined for a constant tensor ..!"; arg->struct_info_ = TensorStructInfo(tsinfo->shape.value(), tsinfo->dtype, vdev, tsinfo->span); return arg; diff --git a/src/relax/backend/adreno/fold_vdevice_scope_change.cc b/src/relax/backend/adreno/fold_vdevice_scope_change.cc index c59beae78e96..a7103cde9577 100644 --- a/src/relax/backend/adreno/fold_vdevice_scope_change.cc +++ b/src/relax/backend/adreno/fold_vdevice_scope_change.cc @@ -54,21 +54,20 @@ std::tuple)>> auto rewriter = [=](Expr expr, ffi::Map matches) -> Expr { const auto* call_tir = matches[pat_call_tir].as(); - ICHECK(call_tir) << "InternalError: " - << "Match of relax.call_tir operator should produce Call, " - << "but instead produces " << matches[pat_call_tir] << " with type " - << matches[pat_call_tir]->GetTypeKey(); + TVM_FFI_CHECK(call_tir, InternalError) + << "Match of relax.call_tir operator should produce Call, " + << "but instead produces " << matches[pat_call_tir] << " with type " + << matches[pat_call_tir]->GetTypeKey(); const auto* out = matches[pattern_out].as(); - ICHECK(out) << "InternalError: " - << "Match of relax.to_vdevice operator should produce Call, " - << "but instead produces " << matches[pattern_out] << " with type " - << matches[pattern_out]->GetTypeKey(); + TVM_FFI_CHECK(out, InternalError) << "Match of relax.to_vdevice operator should produce Call, " + << "but instead produces " << matches[pattern_out] + << " with type " << matches[pattern_out]->GetTypeKey(); const auto* vdev_attrs = out->attrs.as(); - ICHECK(vdev_attrs) << "InternalError: " - << "Attributes for relax.to_vdevice operator should be ToVDeviceAttrs, " - << "but were instead " << out->attrs << " with type " << out->GetTypeKey(); + TVM_FFI_CHECK(vdev_attrs, InternalError) + << "Attributes for relax.to_vdevice operator should be ToVDeviceAttrs, " + << "but were instead " << out->attrs << " with type " << out->GetTypeKey(); const auto* tir_out_sinfo = call_tir->sinfo_args[0].as(); if (!tir_out_sinfo) return expr; diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index ecd9899eceaf..9c61b3db0cd4 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -131,11 +131,11 @@ class OpenCLMLJSONSerializer : public JSONSerializer { std::vector VisitExpr_(const CallNode* call_node) final { // The call must be to an inline "Composite" function const auto* fn_var = call_node->op.as(); - ICHECK(fn_var); + TVM_FFI_ICHECK(fn_var); const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); auto opt_composite = fn->GetAttr(attr::kComposite); - ICHECK(opt_composite.has_value()); + TVM_FFI_ICHECK(opt_composite.has_value()); std::string name = opt_composite.value(); std::shared_ptr node; @@ -187,10 +187,10 @@ class OpenCLMLJSONSerializer : public JSONSerializer { CompositeConvNode nodes{}; const auto* fn_var = cn->op.as(); - ICHECK(fn_var); + TVM_FFI_ICHECK(fn_var); const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); auto opt_composite = fn->GetAttr(attr::kComposite); - ICHECK(opt_composite.has_value()); + TVM_FFI_ICHECK(opt_composite.has_value()); nodes.pad = backend::TryGetOpInFunction(fn, "relax.nn.pad"); nodes.conv = backend::TryGetOpInFunction(fn, "relax.nn.conv2d"); @@ -198,7 +198,7 @@ class OpenCLMLJSONSerializer : public JSONSerializer { if (!nodes.conv) { nodes.conv = backend::TryGetOpInFunction(fn, "relax.nn.conv2d_transpose"); } - ICHECK(nodes.conv) << "No Convolution op found in composite function"; + TVM_FFI_ICHECK(nodes.conv) << "No Convolution op found in composite function"; nodes.bn = backend::TryGetOpInFunction(fn, "relax.nn.batch_norm"); nodes.bias = backend::TryGetOpInFunction(fn, "relax.add"); nodes.activation = backend::TryGetOpInFunction(fn, "relax.nn.relu"); @@ -216,10 +216,10 @@ class OpenCLMLJSONSerializer : public JSONSerializer { CompositeConvNode nodes = UnpackCompositeConvolution(cn); const auto* fn_var = cn->op.as(); - ICHECK(fn_var); + TVM_FFI_ICHECK(fn_var); const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); auto opt_composite = fn->GetAttr(attr::kComposite); - ICHECK(opt_composite.has_value()); + TVM_FFI_ICHECK(opt_composite.has_value()); std::string name = opt_composite.value(); std::vector inputs; @@ -251,7 +251,7 @@ class OpenCLMLJSONSerializer : public JSONSerializer { // Override attributes if (nodes.pad) { const auto* pad_attr = nodes.pad->attrs.as(); - ICHECK(pad_attr); + TVM_FFI_ICHECK(pad_attr); auto p = pad_attr->pad_width; // Pad layout for TVM: dimension wise pre and post padding. // CLML takes dimension wise pre-padding followed by dimension wise post-padding for W, H. diff --git a/src/relax/backend/contrib/codegen_c/codegen_c.h b/src/relax/backend/contrib/codegen_c/codegen_c.h index 3c6469423890..03eb51463409 100644 --- a/src/relax/backend/contrib/codegen_c/codegen_c.h +++ b/src/relax/backend/contrib/codegen_c/codegen_c.h @@ -72,7 +72,7 @@ class CodegenCBase { * \brief Exit a scope. */ void ExitScope() { - ICHECK_GE(indent_, 2U) << "Wrong ident found."; + TVM_FFI_ICHECK_GE(indent_, 2U) << "Wrong ident found."; indent_ -= 2; } @@ -334,7 +334,7 @@ class CodegenCBase { */ std::string GetDtypeString(const Var& var) { auto tsinfo = var->struct_info_.as(); - ICHECK(tsinfo) << "Expect TensorStructInfoNode"; + TVM_FFI_ICHECK(tsinfo) << "Expect TensorStructInfoNode"; return GetDtypeString(tsinfo); } @@ -362,7 +362,7 @@ class CodegenCBase { } else if (runtime::TypeMatch(tsinfo->dtype, kDLUInt, 8)) { dtype = "uint8_t"; } else { - LOG(FATAL) << "Unsupported dtype " << tsinfo->dtype; + TVM_FFI_THROW(InternalError) << "Unsupported dtype " << tsinfo->dtype; } return dtype; @@ -377,7 +377,7 @@ class CodegenCBase { */ std::string CreateInitChecker(const std::string& symbol) const { std::ostringstream oss; - oss << "ICHECK(!" << symbol + oss << "TVM_FFI_ICHECK(!" << symbol << "_consts.empty()) << \"C source module hasn't been initialized.\";\n"; return oss.str(); } diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index c8b7d464c247..d70e6db43fde 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -151,7 +151,7 @@ class OpAttrExtractor { if (auto opt_str = (*an)[i].as()) { attr.push_back(*opt_str); } else { - LOG(FATAL) << "Not supported type: " << (*an)[i].GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Not supported type: " << (*an)[i].GetTypeKey(); } } SetNodeAttr(key, std::move(attr)); @@ -165,7 +165,7 @@ class OpAttrExtractor { } else if (const auto opt_str = (*value).as()) { SetNodeAttr(key, *opt_str); } else { - LOG(FATAL) << "Not yet supported type: " << (*value).GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Not yet supported type: " << (*value).GetTypeKey(); } } @@ -178,7 +178,7 @@ class OpAttrExtractor { private: void VisitObjectFields(Object* obj) { const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); - ICHECK(tinfo->metadata != nullptr) + TVM_FFI_ICHECK(tinfo->metadata != nullptr) << "Object `" << obj->GetTypeKey() << "` misses reflection registration and do not support serialization"; ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { @@ -275,9 +275,9 @@ class JSONSerializer : public relax::MemoizedExprTranslator { if (const auto* tuple_sinfo = struct_info.as()) { for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { const auto* tensor_sinfo = tuple_sinfo->fields[i].as(); - ICHECK(tensor_sinfo) << "Expect TensorStructInfo, but received: ." - << tuple_sinfo->fields[i]->GetTypeKey(); - ICHECK(tensor_sinfo->shape.defined()) << "Expect shape to be defined."; + TVM_FFI_ICHECK(tensor_sinfo) + << "Expect TensorStructInfo, but received: ." << tuple_sinfo->fields[i]->GetTypeKey(); + TVM_FFI_ICHECK(tensor_sinfo->shape.defined()) << "Expect shape to be defined."; ShapeExpr output_shape = Downcast(tensor_sinfo->shape.value()); ret.push_back(JSONGraphNodeEntry(node_id, i)); shape.emplace_back(GetIntShape(output_shape->values)); @@ -286,9 +286,9 @@ class JSONSerializer : public relax::MemoizedExprTranslator { node->SetNumOutput(tuple_sinfo->fields.size()); } else { const auto* tensor_sinfo = struct_info.as(); - ICHECK(tensor_sinfo) << "Expect TensorStructInfo, but received: " - << struct_info->GetTypeKey(); - ICHECK(tensor_sinfo->shape.defined()) << "Expect shape to be defined."; + TVM_FFI_ICHECK(tensor_sinfo) + << "Expect TensorStructInfo, but received: " << struct_info->GetTypeKey(); + TVM_FFI_ICHECK(tensor_sinfo->shape.defined()) << "Expect shape to be defined."; ShapeExpr output_shape = Downcast(tensor_sinfo->shape.value()); shape.emplace_back(GetIntShape(output_shape->values)); @@ -312,15 +312,15 @@ class JSONSerializer : public relax::MemoizedExprTranslator { const Object* call_attr = cn->attrs.get(); extractor.Extract(const_cast(call_attr)); } else if (const auto* fn = cn->op.as()) { - ICHECK(false); + TVM_FFI_ICHECK(false); auto pattern = fn->GetAttr(attr::kPartitionedFromPattern); - ICHECK(pattern.has_value()); + TVM_FFI_ICHECK(pattern.has_value()); node->SetAttr("PartitionedFromPattern", pattern.value()); } } NodeEntries VisitBinding_(const MatchCastNode* binding) { - LOG(FATAL) << "JSON runtime currently doesn't match cast\n"; + TVM_FFI_THROW(InternalError) << "JSON runtime currently doesn't match cast\n"; return {}; } @@ -333,7 +333,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { auto from_b = VisitBinding_(node); nodes.insert(nodes.end(), from_b.begin(), from_b.end()); } else { - LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << binding->GetTypeKey(); } return nodes; } @@ -347,7 +347,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { auto from_bb = VisitBindingBlock_(node); nodes.insert(nodes.end(), from_bb.begin(), from_bb.end()); } else { - LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << block->GetTypeKey(); } return nodes; } @@ -381,13 +381,13 @@ class JSONSerializer : public relax::MemoizedExprTranslator { } NodeEntries VisitExprDefault_(const Object* op) { - LOG(FATAL) << "JSON runtime currently doesn't support " << op->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "JSON runtime currently doesn't support " << op->GetTypeKey(); return {}; } NodeEntries VisitExpr_(const ConstantNode* cn) { auto name = constant_names_.find(ffi::GetRef(cn)); - ICHECK(name != constant_names_.end()) + TVM_FFI_ICHECK(name != constant_names_.end()) << "Cannot find the name of the constant: " << ffi::GetRef(cn); constants_used_.push_back((*name).second); auto node = std::make_shared((*name).second, "const" /* op_type_ */); @@ -410,10 +410,11 @@ class JSONSerializer : public relax::MemoizedExprTranslator { name = op_node->name; } else if (const auto* fn = cn->op.as()) { auto comp = fn->GetAttr(attr::kComposite); - ICHECK(comp.has_value()) << "JSON runtime only supports composite functions."; + TVM_FFI_ICHECK(comp.has_value()) << "JSON runtime only supports composite functions."; name = comp.value(); } else { - LOG(FATAL) << "JSON runtime does not support calls to " << cn->op->GetTypeKey(); + TVM_FFI_THROW(InternalError) + << "JSON runtime does not support calls to " << cn->op->GetTypeKey(); } // TODO(@sunggg): Revisit when we have op naming convention. @@ -438,7 +439,7 @@ class JSONSerializer : public relax::MemoizedExprTranslator { } NodeEntries VisitExpr_(const FunctionNode* fn) { - ICHECK(fn->GetAttr(attr::kComposite).has_value()) + TVM_FFI_ICHECK(fn->GetAttr(attr::kComposite).has_value()) << "JSON runtime only supports composite functions"; // FunctionNode should be handled by the caller. diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc index 424c0b9d4d37..7e750e2ec861 100644 --- a/src/relax/backend/contrib/cublas/codegen.cc +++ b/src/relax/backend/contrib/cublas/codegen.cc @@ -48,12 +48,12 @@ class CublasJSONSerializer : public JSONSerializer { NodeEntries VisitExpr_(const CallNode* call_node) final { const auto* fn_var = call_node->op.as(); - ICHECK(fn_var); + TVM_FFI_ICHECK(fn_var); const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); - ICHECK(fn.defined()) << "Expects the callee to be a function."; + TVM_FFI_ICHECK(fn.defined()) << "Expects the callee to be a function."; auto composite_opt = fn->GetAttr(attr::kComposite); - ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; + TVM_FFI_ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); @@ -63,7 +63,7 @@ class CublasJSONSerializer : public JSONSerializer { inputs_tmp.insert(inputs_tmp.end(), res.begin(), res.end()); } - ICHECK(inputs_tmp.size() <= 4); + TVM_FFI_ICHECK(inputs_tmp.size() <= 4); NodeEntries inputs(inputs_tmp.size()); auto arg_idx = backend::ExtractArgIdx(composite_name, fn); @@ -88,7 +88,7 @@ class CublasJSONSerializer : public JSONSerializer { if (sinfo->dtype == DataType::Float(16)) { alpha = __gnu_h2f_ieee(static_cast(const_expr->data->data)[0]); } else { - ICHECK(sinfo->dtype == DataType::Float(32)); + TVM_FFI_ICHECK(sinfo->dtype == DataType::Float(32)); alpha = static_cast(const_expr->data->data)[0]; } diff --git a/src/relax/backend/contrib/cudnn/codegen.cc b/src/relax/backend/contrib/cudnn/codegen.cc index 9f6d9b6f45c0..e1497610d150 100644 --- a/src/relax/backend/contrib/cudnn/codegen.cc +++ b/src/relax/backend/contrib/cudnn/codegen.cc @@ -47,12 +47,12 @@ class cuDNNJSONSerializer : public JSONSerializer { NodeEntries VisitExpr_(const CallNode* call_node) final { const auto* fn_var = call_node->op.as(); - ICHECK(fn_var); + TVM_FFI_ICHECK(fn_var); const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); - ICHECK(fn.defined()) << "Expects the callee to be a function."; + TVM_FFI_ICHECK(fn.defined()) << "Expects the callee to be a function."; auto composite_opt = fn->GetAttr(attr::kComposite); - ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; + TVM_FFI_ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); @@ -61,7 +61,7 @@ class cuDNNJSONSerializer : public JSONSerializer { } else if (composite_name.find("cudnn.attention") != std::string::npos) { return HandleAttention(call_node, fn, composite_name); } else { - LOG(FATAL) << "Unsupported composite function: " << composite_name; + TVM_FFI_THROW(InternalError) << "Unsupported composite function: " << composite_name; } } @@ -73,7 +73,7 @@ class cuDNNJSONSerializer : public JSONSerializer { inputs_tmp.insert(inputs_tmp.end(), res.begin(), res.end()); } - ICHECK(inputs_tmp.size() <= 3); + TVM_FFI_ICHECK(inputs_tmp.size() <= 3); NodeEntries inputs(inputs_tmp.size()); auto arg_idx = backend::ExtractArgIdx(composite_name, fn); @@ -100,7 +100,7 @@ class cuDNNJSONSerializer : public JSONSerializer { auto res = VisitExpr(arg); inputs.insert(inputs.end(), res.begin(), res.end()); } - ICHECK_EQ(inputs.size(), 2); + TVM_FFI_ICHECK_EQ(inputs.size(), 2); auto node = std::make_shared(composite_name, /* name_ */ "kernel", /* op_type_ */ inputs, 1 /* num_outputs_ */); diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc index 69da3d6058ed..bc24fad3b9b1 100644 --- a/src/relax/backend/contrib/cutlass/codegen.cc +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -56,7 +56,7 @@ std::string EmitSignature(const std::vector& out, const std::string& fun } ffi::Module Finalize(const std::string& code, const ffi::Array& func_names) { - ICHECK(!func_names.empty()) + TVM_FFI_ICHECK(!func_names.empty()) << "Should only create CUTLASS CSourceModule if there is at least one CUTLASS partition"; std::ostringstream default_headers; @@ -116,7 +116,7 @@ GenerateBodyOutput GenerateBody(const std::string& func_name, const std::string& const ffi::Array& func_args, const ffi::Map& attrs, int* buf_idx) { // Make function call with input buffers when visiting arguements - ICHECK_GT(func_args.size(), 0); + TVM_FFI_ICHECK_GT(func_args.size(), 0); std::ostringstream decl_stream; decl_stream << "(" << func_args[0]; for (size_t i = 1; i < func_args.size(); ++i) { @@ -168,7 +168,7 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, } else if (const auto* shape_sinfo = sinfo.as()) { arg_types.emplace_back(backend::DType2String(shape_sinfo->values.value()[0]->dtype)); } else { - LOG(FATAL) << "Unimplemented"; + TVM_FFI_THROW(InternalError) << "Unimplemented"; } arg_names.push_back(var_name_map_.at(arg.get())); } @@ -200,17 +200,17 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, OutputType VisitExpr_(const VarNode* node) final { Output output; auto it = var_name_map_.find(node); - ICHECK(it != var_name_map_.end()); + TVM_FFI_ICHECK(it != var_name_map_.end()); output.name = it->second; return {output}; } OutputType VisitExpr_(const CallNode* call) final { const auto* fn_var = call->op.as(); - ICHECK(fn_var); + TVM_FFI_ICHECK(fn_var); const auto func = Downcast(bindings_[ffi::GetRef(fn_var)]); const auto pattern_name_opt = func->GetAttr(attr::kComposite); - ICHECK(pattern_name_opt) << "Only composite function is supported for CUTLASS."; + TVM_FFI_ICHECK(pattern_name_opt) << "Only composite function is supported for CUTLASS."; auto ret = GenerateBody(call, pattern_name_opt.value(), func->attrs->dict); ext_func_body_.push_back(ret.decl); headers_ = ret.headers; @@ -218,7 +218,7 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, } OutputType VisitExpr_(const FunctionNode* fn) final { - ICHECK(fn->GetAttr(attr::kComposite).has_value()) + TVM_FFI_ICHECK(fn->GetAttr(attr::kComposite).has_value()) << "JSON runtime only supports composite functions"; // FunctionNode should be handled by the caller. return {}; @@ -230,7 +230,7 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, auto from_b = VisitBinding_(node); outputs.insert(outputs.end(), from_b.begin(), from_b.end()); } else { - LOG(FATAL) << "Unimplemented type: " << binding->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unimplemented type: " << binding->GetTypeKey(); } return outputs; } @@ -244,7 +244,7 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, auto from_bb = VisitBindingBlock_(node); outputs.insert(outputs.end(), from_bb.begin(), from_bb.end()); } else { - LOG(FATAL) << "Unimplemented type: " << block->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unimplemented type: " << block->GetTypeKey(); } return outputs; } @@ -301,7 +301,7 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, if (const auto* tensor_sinfo = struct_info.as()) { out_types.emplace_back(backend::DType2String(tensor_sinfo->dtype)); } else { - LOG(FATAL) << "Unimplemented sinfo type: " << struct_info; + TVM_FFI_THROW(InternalError) << "Unimplemented sinfo type: " << struct_info; } return contrib::GenerateBody(func_name, ext_func_id_, out_types, func_args, attrs, &buf_idx_); @@ -353,7 +353,7 @@ class CutlassModuleCodegen { private: std::pair> GenCutlassFunc( const Function& function, const ffi::Map& options) { - ICHECK(function.defined()) << "Input error: expect a Relax function."; + TVM_FFI_ICHECK(function.defined()) << "Input error: expect a Relax function."; auto sid = GetExtSymbol(function); func_names_.push_back(sid); @@ -376,7 +376,7 @@ ffi::Array CUTLASSCompiler(ffi::Array functions, ffi::Map options, ffi::Map /*unused*/) { const auto tune_func = tvm::ffi::Function::GetGlobal("contrib.cutlass.tune_relax_function"); - ICHECK(tune_func.has_value()) + TVM_FFI_ICHECK(tune_func.has_value()) << "The packed function contrib.cutlass.tune_relax_function not found, " "please import tvm.contrib.cutlass.build"; @@ -384,8 +384,9 @@ ffi::Array CUTLASSCompiler(ffi::Array functions, auto source_mod = CutlassModuleCodegen().CreateCSourceModule(annotated_functions, options); const auto pf = tvm::ffi::Function::GetGlobal("contrib.cutlass.compile"); - ICHECK(pf.has_value()) << "The packed function contrib.cutlass.compile not found, please import " - "tvm.contrib.cutlass.build"; + TVM_FFI_ICHECK(pf.has_value()) + << "The packed function contrib.cutlass.compile not found, please import " + "tvm.contrib.cutlass.build"; ffi::Module cutlass_mod = (*pf)(source_mod, options).cast(); return {cutlass_mod}; diff --git a/src/relax/backend/contrib/dnnl/codegen.cc b/src/relax/backend/contrib/dnnl/codegen.cc index e903ed885296..5c92ab4de2a5 100644 --- a/src/relax/backend/contrib/dnnl/codegen.cc +++ b/src/relax/backend/contrib/dnnl/codegen.cc @@ -47,12 +47,12 @@ class DNNLJSONSerializer : public JSONSerializer { NodeEntries VisitExpr_(const CallNode* call_node) final { const auto* fn_var = call_node->op.as(); - ICHECK(fn_var); + TVM_FFI_ICHECK(fn_var); const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); - ICHECK(fn.defined()) << "Expects the callee to be a function."; + TVM_FFI_ICHECK(fn.defined()) << "Expects the callee to be a function."; auto composite_opt = fn->GetAttr(attr::kComposite); - ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; + TVM_FFI_ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); @@ -69,7 +69,7 @@ class DNNLJSONSerializer : public JSONSerializer { if (composite_name.find("conv2d") != std::string::npos) { root_call = backend::GetOpInFunction(fn, "relax.nn.conv2d"); } else { - LOG(FATAL) << "Unimplemented pattern: " << composite_name; + TVM_FFI_THROW(InternalError) << "Unimplemented pattern: " << composite_name; } SetCallNodeAttribute(node, root_call); diff --git a/src/relax/backend/contrib/hipblas/codegen.cc b/src/relax/backend/contrib/hipblas/codegen.cc index 09a0f0026789..573ce686e56b 100644 --- a/src/relax/backend/contrib/hipblas/codegen.cc +++ b/src/relax/backend/contrib/hipblas/codegen.cc @@ -48,12 +48,12 @@ class HipblasJSONSerializer : public JSONSerializer { NodeEntries VisitExpr_(const CallNode* call_node) final { const auto* fn_var = call_node->op.as(); - ICHECK(fn_var); + TVM_FFI_ICHECK(fn_var); const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); - ICHECK(fn.defined()) << "Expects the callee to be a function."; + TVM_FFI_ICHECK(fn.defined()) << "Expects the callee to be a function."; auto composite_opt = fn->GetAttr(attr::kComposite); - ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; + TVM_FFI_ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); @@ -63,7 +63,7 @@ class HipblasJSONSerializer : public JSONSerializer { inputs_tmp.insert(inputs_tmp.end(), res.begin(), res.end()); } - ICHECK(inputs_tmp.size() <= 3); + TVM_FFI_ICHECK(inputs_tmp.size() <= 3); NodeEntries inputs(inputs_tmp.size()); auto arg_idx = backend::ExtractArgIdx(composite_name, fn); diff --git a/src/relax/backend/contrib/nnapi/codegen.cc b/src/relax/backend/contrib/nnapi/codegen.cc index 888feee15041..afc87f16b07a 100644 --- a/src/relax/backend/contrib/nnapi/codegen.cc +++ b/src/relax/backend/contrib/nnapi/codegen.cc @@ -53,7 +53,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { void SetPermuteDimsAttribute(const CallNode* call_node) { const auto* permute_dims_attr = call_node->attrs.as(); - ICHECK(permute_dims_attr); + TVM_FFI_ICHECK(permute_dims_attr); if (permute_dims_attr->axes) { ffi::Array axes; for (auto axis : permute_dims_attr->axes.value()) { @@ -65,14 +65,14 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { void SetAstypeAttribute(const CallNode* call_node) { const auto* astype_attrs = call_node->attrs.as(); - ICHECK(astype_attrs); + TVM_FFI_ICHECK(astype_attrs); node_->SetAttr("astype_dtype", ffi::String(runtime::DLDataTypeToString(astype_attrs->dtype))); } void SetMeanAttribute(const CallNode* call_node) { const auto* mean_attrs = call_node->attrs.as(); - ICHECK(mean_attrs); - ICHECK(mean_attrs->axis.defined()); + TVM_FFI_ICHECK(mean_attrs); + TVM_FFI_ICHECK(mean_attrs->axis.defined()); { ffi::Array axis; @@ -86,7 +86,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { void SetConv2dAttribute(const CallNode* call_node) { const auto* conv2d_attr = call_node->attrs.as(); - ICHECK(conv2d_attr) << "didn't catch attributes"; + TVM_FFI_ICHECK(conv2d_attr) << "didn't catch attributes"; ffi::Array strides; if (!conv2d_attr->strides.empty()) { @@ -110,7 +110,7 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { void SetMaxPool2dAttribute(const CallNode* call_node) { const auto* max_pool_2d_attr = call_node->attrs.as(); - ICHECK(max_pool_2d_attr) << "didn't catch attributes"; + TVM_FFI_ICHECK(max_pool_2d_attr) << "didn't catch attributes"; ffi::Array strides; if (!max_pool_2d_attr->strides.empty()) { @@ -150,12 +150,12 @@ class NNAPIJSONSerializer : public JSONSerializer { std::vector VisitExpr_(const CallNode* call_node) final { const auto* fn_var = call_node->op.as(); - ICHECK(fn_var); + TVM_FFI_ICHECK(fn_var); const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); - ICHECK(fn.defined()) << "Expects the callee to be a function."; + TVM_FFI_ICHECK(fn.defined()) << "Expects the callee to be a function."; auto composite_opt = fn->GetAttr(attr::kComposite); - ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; + TVM_FFI_ICHECK(composite_opt.has_value()) << "Only composite functions are supported."; std::string composite_name = composite_opt.value(); @@ -184,7 +184,7 @@ class NNAPIJSONSerializer : public JSONSerializer { void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { const auto* op_node = call_node->op.as(); - ICHECK(op_node != nullptr); + TVM_FFI_ICHECK(op_node != nullptr); std::string name = op_node->name; if (name == "relax.permute_dims") { SetPermuteDimsAttribute(call_node); diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc b/src/relax/backend/contrib/tensorrt/codegen.cc index cafc5f2e6330..9fcd4f43be8e 100644 --- a/src/relax/backend/contrib/tensorrt/codegen.cc +++ b/src/relax/backend/contrib/tensorrt/codegen.cc @@ -135,11 +135,11 @@ class TensorRTJSONSerializer : public JSONSerializer { std::vector VisitExpr_(const CallNode* call_node) final { // The call must be to an inline "Composite" function const auto* fn_var = call_node->op.as(); - ICHECK(fn_var); + TVM_FFI_ICHECK(fn_var); const auto fn = Downcast(bindings_[ffi::GetRef(fn_var)]); auto opt_composite = fn->GetAttr(attr::kComposite); - ICHECK(opt_composite.has_value()); + TVM_FFI_ICHECK(opt_composite.has_value()); std::string name = opt_composite.value(); // Collect the constants and attributes of all operator calls inside the composite body. @@ -180,7 +180,7 @@ class TensorRTJSONSerializer : public JSONSerializer { if (!cfg.defined()) { cfg = AttrsWithDefaultValues(); } - ICHECK_EQ(cfg.value()->tensorrt_version.size(), 3); + TVM_FFI_ICHECK_EQ(cfg.value()->tensorrt_version.size(), 3); ffi::Array tensorrt_version = {cfg.value()->tensorrt_version[0].IntValue(), cfg.value()->tensorrt_version[1].IntValue(), cfg.value()->tensorrt_version[2].IntValue()}; diff --git a/src/relax/backend/contrib/utils.cc b/src/relax/backend/contrib/utils.cc index 1840986c019d..ab1cdf0e4ae7 100644 --- a/src/relax/backend/contrib/utils.cc +++ b/src/relax/backend/contrib/utils.cc @@ -34,15 +34,14 @@ namespace backend { ffi::Map ExtractArgIdx(ffi::String pattern_name, Function f) { ffi::Map arg_idx; auto pattern = backend::GetPattern(pattern_name); - ICHECK(pattern) << "Unsupported op_type " << pattern_name; + TVM_FFI_ICHECK(pattern) << "Unsupported op_type " << pattern_name; auto bindings = AnalyzeVar2Value(f); auto inner_body = f->body->body; auto matched_expr = relax::ExtractMatchedExpr(pattern.value()->pattern, inner_body, bindings); - ICHECK(matched_expr) << "ValueError: " - << "For named pattern \"" << pattern_name - << "\", expected to find a match for " << pattern.value()->pattern - << ". However, the function did not include this pattern " << f; + TVM_FFI_CHECK(matched_expr, ValueError) + << "For named pattern \"" << pattern_name << "\", expected to find a match for " + << pattern.value()->pattern << ". However, the function did not include this pattern " << f; auto find_index = [](const ffi::Array& params, Var v) -> std::optional { for (size_t i = 0; i < params.size(); ++i) { diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h index e1bcfd0aee1e..e13ab6c65e8d 100644 --- a/src/relax/backend/contrib/utils.h +++ b/src/relax/backend/contrib/utils.h @@ -101,7 +101,7 @@ inline const CallNode* TryGetOpInFunction(Function f, const std::string& op_name */ inline const CallNode* GetOpInFunction(Function f, const std::string& op_name) { const CallNode* op = TryGetOpInFunction(f, op_name); - ICHECK(op) << op_name << " not found in the function:\n" << f; + TVM_FFI_ICHECK(op) << op_name << " not found in the function:\n" << f; return op; } diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc index 154ca330981f..879d49a22355 100644 --- a/src/relax/backend/task_extraction.cc +++ b/src/relax/backend/task_extraction.cc @@ -90,7 +90,7 @@ class TaskExtractor : public ExprVisitor { mod_eq_(ModuleEquality::Create(mod_eq_name)), func2task_(/*bucket_count*/ 0, ModuleHash(*mod_eq_), ModuleEqual(*mod_eq_)) { normalize_mod_func_ = tvm::ffi::Function::GetGlobal("tvm.s_tir.meta_schedule.normalize_mod"); - ICHECK(normalize_mod_func_.has_value()) << "Normalization function is not found."; + TVM_FFI_ICHECK(normalize_mod_func_.has_value()) << "Normalization function is not found."; } void VisitExpr_(const CallNode* call) final { diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index e545737f90fb..8736d91af25c 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -84,8 +84,9 @@ class CodeGenVM : public ExprFunctor { void Codegen(const Function& func) { ffi::Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(gsymbol.has_value()) << "there should be no local functions in Relax VM codegen phase. " - "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; + TVM_FFI_ICHECK(gsymbol.has_value()) + << "there should be no local functions in Relax VM codegen phase. " + "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; ffi::Array param_names; for (Var param : func->params) { @@ -96,7 +97,7 @@ class CodeGenVM : public ExprFunctor { for (size_t i = 0; i < func->params.size(); ++i) { RegName r = NewRegister(); - ICHECK_EQ(r, static_cast(i)); + TVM_FFI_ICHECK_EQ(r, static_cast(i)); this->var_arg_map_.insert({func->params[i], Instruction::Arg::Register(r)}); } Instruction::Arg ret = ExprFunctor::VisitExpr(func->body); @@ -155,7 +156,8 @@ class CodeGenVM : public ExprFunctor { } else { // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those // ops are handled in a pass when lowering them to TIR. - LOG(FATAL) << "CodeGenVM cannot handle this intrinsic now:\n" << call_node->op; + TVM_FFI_THROW(InternalError) << "CodeGenVM cannot handle this intrinsic now:\n" + << call_node->op; } } else { EmitNormalCall(call, dst_reg); @@ -210,7 +212,7 @@ class CodeGenVM : public ExprFunctor { Instruction::Arg VisitExpr_(const VarNode* op) final { Var var = ffi::GetRef(op); auto it = this->var_arg_map_.find(var); - ICHECK(it != this->var_arg_map_.end()) << "Var " << var << " is not defined"; + TVM_FFI_ICHECK(it != this->var_arg_map_.end()) << "Var " << var << " is not defined"; return it->second; } @@ -232,7 +234,8 @@ class CodeGenVM : public ExprFunctor { if (auto* int_value = e.as()) { shape.push_back(int_value->value); } else { - LOG(FATAL) << "Should only use constant shape after shape lowering: " << op->values; + TVM_FFI_THROW(InternalError) + << "Should only use constant shape after shape lowering: " << op->values; } } return builder_->ConvertConstant(ffi::Shape(shape)); @@ -244,9 +247,9 @@ class CodeGenVM : public ExprFunctor { } else if (auto* float_imm = op->value.as()) { return builder_->ConvertConstant(float_imm->value); } else { - LOG(FATAL) << "PrimValue should only contain constant after VMShapeLower, " - << "but received " << ffi::GetRef(op) << " with type " - << op->value->GetTypeKey(); + TVM_FFI_THROW(InternalError) + << "PrimValue should only contain constant after VMShapeLower, " + << "but received " << ffi::GetRef(op) << " with type " << op->value->GetTypeKey(); TVM_FFI_UNREACHABLE(); } } @@ -309,7 +312,7 @@ class CodeGenVM : public ExprFunctor { kind = VMFuncInfo::FuncKind::kPackedFunc; } // declare the function to be safe. - ICHECK(symbol.has_value()); + TVM_FFI_ICHECK(symbol.has_value()); builder_->DeclareFunction(symbol.value(), kind); return builder_->GetFunction(symbol.value()); } @@ -331,7 +334,7 @@ class CodeGenVM : public ExprFunctor { } void EmitAllocStorage(const Call& call_node, RegName dst_reg) { - ICHECK_EQ(call_node->args.size(), 4); + TVM_FFI_ICHECK_EQ(call_node->args.size(), 4); // Handle args of the call std::vector args; args.push_back(Instruction::Arg::Register(Instruction::kVMRegister)); @@ -343,7 +346,7 @@ class CodeGenVM : public ExprFunctor { } void EmitAllocTensor(const Call& call_node, RegName dst_reg) { - ICHECK_EQ(call_node->args.size(), 5); + TVM_FFI_ICHECK_EQ(call_node->args.size(), 5); std::vector args; for (int i = 0; i < 4; ++i) { args.push_back(this->VisitExpr(call_node->args[i])); @@ -361,9 +364,9 @@ class CodeGenVM : public ExprFunctor { } RegName EmitKillObject(const Call& call_node) { - ICHECK_EQ(call_node->args.size(), 1); + TVM_FFI_ICHECK_EQ(call_node->args.size(), 1); Instruction::Arg arg = this->VisitExpr(call_node->args[0]); - ICHECK(arg.kind() == Instruction::ArgKind::kRegister) + TVM_FFI_ICHECK(arg.kind() == Instruction::ArgKind::kRegister) << "Expected the object to be killed to be stored in a register, " << "but argument " << call_node->args[0] << " produced VM instruction of type " << arg.kind(); @@ -473,7 +476,8 @@ void LinkModules(ObjectPtr exec, const ffi::Map(const Expr&)> { } void EmitStmt(tir::Stmt stmt) { - ICHECK(!stmt_stack_.empty()); + TVM_FFI_ICHECK(!stmt_stack_.empty()); stmt_stack_.back().emplace_back(stmt); } @@ -130,7 +130,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { void EmitCallCPacked(const tir::PrimFunc& prim_func, const ffi::Array& args, int64_t dst_anylist_slot = -1) { ffi::Optional gsymbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(gsymbol.has_value()) << "All functions must have global symbol at this phase"; + TVM_FFI_ICHECK(gsymbol.has_value()) << "All functions must have global symbol at this phase"; ffi::Array all_args; // negative index indicate return value can be discarded, emit call_packed if (dst_anylist_slot >= 0) { @@ -151,8 +151,9 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { tir::PrimFunc Codegen(const Function& func) { ffi::Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(gsymbol.has_value()) << "there should be no local functions in Relax VM codegen phase. " - "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; + TVM_FFI_ICHECK(gsymbol.has_value()) + << "there should be no local functions in Relax VM codegen phase. " + "Did you forget to apply LambdaLift or AttachGlobalSymbol Pass?"; // initialize the state stmt_stack_ = {}; registers_num_ = 0; @@ -171,7 +172,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { for (size_t i = 0; i < func->params.size(); ++i) { int64_t r = NewRegister(); - ICHECK_EQ(static_cast(r), i); + TVM_FFI_ICHECK_EQ(static_cast(r), i); this->var_map_.insert({func->params[i], RegListGet(r)}); } size_t ret_reg = NewRegister(); @@ -243,7 +244,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } else { // every "normal" operator is lowered to a global var in the IRModule. The Attrs for those // ops are handled in a pass when lowering them to TIR. - LOG(FATAL) << "CodeGenVMTIR cannot handle this intrinsic now:\n" << call_node->op; + TVM_FFI_THROW(InternalError) << "CodeGenVMTIR cannot handle this intrinsic now:\n" + << call_node->op; } } else { EmitNormalCall(call, dst_reg); @@ -278,7 +280,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { ffi::Optional VisitExpr_(const VarNode* op) final { Var var = ffi::GetRef(op); auto it = this->var_map_.find(var); - ICHECK(it != this->var_map_.end()) << "Var " << var << " is not defined"; + TVM_FFI_ICHECK(it != this->var_map_.end()) << "Var " << var << " is not defined"; return it->second; } @@ -292,7 +294,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { if (auto* int_value = e.as()) { shape.push_back(int_value->value); } else { - LOG(FATAL) << "Should only use constant shape after shape lowering: " << op->values; + TVM_FFI_THROW(InternalError) + << "Should only use constant shape after shape lowering: " << op->values; } } return ConstListGet(builder_->ConvertConstant(ffi::Shape(shape)).value()); @@ -382,7 +385,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { ffi::Optional VisitExpr_(const GlobalVarNode* op) final { VMFuncInfo::FuncKind kind; auto symbol = LookupFunction(ffi::GetRef(op), &kind); - ICHECK(symbol.has_value()); + TVM_FFI_ICHECK(symbol.has_value()); builder_->DeclareFunction(symbol.value(), kind); return FuncListGet(builder_->GetFunction(symbol.value()).value()); } @@ -403,7 +406,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } void EmitAllocTensor(const Call& call_node, int64_t dst_reg) { - ICHECK_EQ(call_node->args.size(), 5); + TVM_FFI_ICHECK_EQ(call_node->args.size(), 5); ffi::Array args; for (int i = 0; i < 4; ++i) { args.push_back(this->VisitExpr(call_node->args[i]).value()); @@ -422,18 +425,18 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } int64_t EmitKillObject(const Call& call_node) { - ICHECK_EQ(call_node->args.size(), 1); + TVM_FFI_ICHECK_EQ(call_node->args.size(), 1); PrimExpr arg = this->VisitExpr(call_node->args[0]).value(); // Check the arg is a register. const auto* tir_call = arg.as(); - ICHECK(tir_call != nullptr); - ICHECK(tir_call->op == tir::builtin::anylist_getitem()); - ICHECK(tir_call->args.size() == 2); - ICHECK(tir_call->args[0].same_as(reg_anylist_handle_)); + TVM_FFI_ICHECK(tir_call != nullptr); + TVM_FFI_ICHECK(tir_call->op == tir::builtin::anylist_getitem()); + TVM_FFI_ICHECK(tir_call->args.size() == 2); + TVM_FFI_ICHECK(tir_call->args[0].same_as(reg_anylist_handle_)); const auto* p_dst_reg = tir_call->args[1].as(); - ICHECK(p_dst_reg != nullptr); - ICHECK(p_dst_reg->dtype == DataType::Int(32)); + TVM_FFI_ICHECK(p_dst_reg != nullptr); + TVM_FFI_ICHECK(p_dst_reg->dtype == DataType::Int(32)); int64_t dst_reg = p_dst_reg->value; this->EmitCallPacked("vm.builtin.null_value", {}, dst_reg); @@ -445,7 +448,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { // if context is required, pass as first argument. args.push_back(ctx_ptr_); auto* func = call_node->args[0].as(); - ICHECK(func) << "CallBuiltin comes with extern func"; + TVM_FFI_ICHECK(func) << "CallBuiltin comes with extern func"; auto tuple_arg = Downcast(call_node->args[1]); diff --git a/src/relax/backend/vm/exec_builder.cc b/src/relax/backend/vm/exec_builder.cc index bcbff2ad84d0..ac394f93fb90 100644 --- a/src/relax/backend/vm/exec_builder.cc +++ b/src/relax/backend/vm/exec_builder.cc @@ -73,7 +73,7 @@ vm::Instruction::Arg ExecBuilderNode::ConvertConstant_(Any cvalue) { void ExecBuilderNode::DeclareFunction(const std::string& func_name, VMFuncInfo::FuncKind kind) { auto it = exec_->func_map.find(func_name); if (it != exec_->func_map.end()) { - ICHECK(kind == exec_->func_table[it->second].kind) + TVM_FFI_ICHECK(kind == exec_->func_table[it->second].kind) << "Function " << func_name << "already declared in a different kind"; return; } @@ -90,7 +90,7 @@ void ExecBuilderNode::DeclareFunction(const std::string& func_name, VMFuncInfo:: vm::Instruction::Arg ExecBuilderNode::GetFunction(const std::string& func_name) { auto it = exec_->func_map.find(func_name); - ICHECK(it != exec_->func_map.end()) << "Cannot find function " << func_name; + TVM_FFI_ICHECK(it != exec_->func_map.end()) << "Cannot find function " << func_name; return vm::Instruction::Arg::FuncIdx(it->second); } @@ -102,11 +102,11 @@ void ExecBuilderNode::EmitFunction(const std::string& func_name, int64_t num_inp this->DeclareFunction(func_name, kind); } auto& vmfunc = exec_->func_table.at(exec_->func_map.at(func_name)); - ICHECK_EQ(vmfunc.name, func_name); - ICHECK_EQ(vmfunc.num_args, -2) << "Function " << func_name << " already defined"; + TVM_FFI_ICHECK_EQ(vmfunc.name, func_name); + TVM_FFI_ICHECK_EQ(vmfunc.num_args, -2) << "Function " << func_name << " already defined"; vmfunc.num_args = num_inputs; if (param_names.defined()) { - ICHECK_EQ(num_inputs, param_names.value().size()) + TVM_FFI_ICHECK_EQ(num_inputs, param_names.value().size()) << "Function " << func_name << " defined with " << num_inputs << " arguments, " << "but the list of parameter names has " << param_names.value().size() << " names (" << param_names << ")"; @@ -124,9 +124,9 @@ void ExecBuilderNode::EmitFunction(const std::string& func_name, int64_t num_inp void ExecBuilderNode::EndFunction(const std::string& func_name) { auto it = exec_->func_map.find(func_name); - ICHECK(it != exec_->func_map.end()); + TVM_FFI_ICHECK(it != exec_->func_map.end()); VMFuncInfo& vmfunc = exec_->func_table.at(it->second); - ICHECK_EQ(vmfunc.end_instr, 0) << "EndFuncton can only be called once"; + TVM_FFI_ICHECK_EQ(vmfunc.end_instr, 0) << "EndFuncton can only be called once"; if (vmfunc.kind == vm::VMFuncInfo::FuncKind::kVMFunc) { vmfunc.end_instr = exec_->instr_offset.size(); @@ -135,7 +135,7 @@ void ExecBuilderNode::EndFunction(const std::string& func_name) { void ExecBuilderNode::EmitCall(vm::Instruction::Arg func, std::vector args, vm::RegName dst) { - ICHECK(func.kind() == vm::Instruction::ArgKind::kFuncIdx); + TVM_FFI_ICHECK(func.kind() == vm::Instruction::ArgKind::kFuncIdx); // store instruction exec_->instr_offset.push_back(exec_->instr_data.size()); exec_->instr_data.push_back(static_cast(Opcode::Call)); @@ -158,7 +158,7 @@ void ExecBuilderNode::EmitCall(const std::string& func, std::vectorinstr_offset.push_back(exec_->instr_data.size()); exec_->instr_data.push_back(static_cast(Opcode::Ret)); exec_->instr_data.push_back(result.value()); @@ -171,7 +171,7 @@ void ExecBuilderNode::EmitGoto(Index pc_offset) { } void ExecBuilderNode::EmitIf(vm::Instruction::Arg cond, vm::Index false_offset) { - ICHECK(cond.kind() == vm::Instruction::ArgKind::kRegister); + TVM_FFI_ICHECK(cond.kind() == vm::Instruction::ArgKind::kRegister); exec_->instr_offset.push_back(exec_->instr_data.size()); exec_->instr_data.push_back(static_cast(Opcode::If)); exec_->instr_data.push_back(cond.value()); @@ -182,7 +182,7 @@ void ExecBuilderNode::CheckExecutable() { for (auto it = exec_->func_table.cbegin(); it != exec_->func_table.cend(); ++it) { if (it->kind == VMFuncInfo::FuncKind::kPackedFunc) continue; if (it->kind == VMFuncInfo::FuncKind::kVMTIRFunc) { - ICHECK_GE(it->register_file_size, it->num_args + 1) + TVM_FFI_ICHECK_GE(it->register_file_size, it->num_args + 1) << "Function " << it->name << " do not meet register file constraint."; continue; } @@ -192,7 +192,7 @@ void ExecBuilderNode::CheckExecutable() { size_t start_instr = it->start_instr; size_t end_instr = it->end_instr; - CHECK_LT(start_instr, end_instr) + TVM_FFI_ICHECK_LT(start_instr, end_instr) << "Function " << it->name << " EndFunction has not be been called"; auto check_reg_defined = [&](Instruction::Arg arg) { @@ -201,23 +201,23 @@ void ExecBuilderNode::CheckExecutable() { if (arg.value() < num_inputs) return; if (dst_registers.find(arg.value()) == dst_registers.end()) { - LOG(FATAL) << "register r(" << arg.value() << ") in VM function \"" << it->name - << "\" is used as input while it is never defined" - << " as a destination. Dump:\n" - << exec_->AsText(); + TVM_FFI_THROW(InternalError) << "register r(" << arg.value() << ") in VM function \"" + << it->name << "\" is used as input while it is never defined" + << " as a destination. Dump:\n" + << exec_->AsText(); } }; auto check_const_defined = [&](Instruction::Arg arg) { if (arg.kind() != Instruction::ArgKind::kConstIdx) return; - CHECK_LT(arg.value(), exec_->constants.size()) + TVM_FFI_ICHECK_LT(arg.value(), exec_->constants.size()) << "Constant index " << arg.value() << " exceed size of constant pool. Dump:\n" << exec_->AsText(); }; auto check_func_defined = [&](Instruction::Arg arg) { if (arg.kind() != Instruction::ArgKind::kFuncIdx) return; - CHECK_LT(arg.value(), exec_->func_table.size()) + TVM_FFI_ICHECK_LT(arg.value(), exec_->func_table.size()) << "Func index " << arg.value() << " exceed size of fun_table. Dump:\n" << exec_->AsText(); }; @@ -244,17 +244,18 @@ void ExecBuilderNode::CheckExecutable() { break; } case Opcode::Goto: { - ICHECK_NE(instr.pc_offset, 0); + TVM_FFI_ICHECK_NE(instr.pc_offset, 0); break; } case Opcode::If: { - ICHECK_GT(instr.false_offset, 1); + TVM_FFI_ICHECK_GT(instr.false_offset, 1); check_reg_defined(Instruction::Arg::Register(instr.cond)); arg_registers.emplace(instr.cond); break; } default: - LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + TVM_FFI_THROW(InternalError) + << "should never hit this case: " << static_cast(instr.op); break; } } @@ -316,7 +317,8 @@ void ExecBuilderNode::Formalize() { break; } default: - LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + TVM_FFI_THROW(InternalError) + << "should never hit this case: " << static_cast(instr.op); break; } } diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index a1ffa4618423..ac30fbed1e16 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -62,10 +62,10 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { } else if (call->op == invoke_closure_op_) { return InvokeClosure(call); } else if (call->op == alloc_tensor_op_) { - LOG(FATAL) << "VMBuiltinLower encountered " << call->op << " in expression " - << ffi::GetRef(call_node) << ". " - << "This operation should have been lowered earlier " - << "using the 'relax.transform.LowerAllocTensor' pass."; + TVM_FFI_THROW(InternalError) << "VMBuiltinLower encountered " << call->op << " in expression " + << ffi::GetRef(call_node) << ". " + << "This operation should have been lowered earlier " + << "using the 'relax.transform.LowerAllocTensor' pass."; } else if (call->op == mem_alloc_storage_op_) { return MakeMemAllocStorage(call); } else if (call->op == mem_alloc_tensor_op_) { @@ -102,14 +102,14 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { } Expr MakeMemKillObject(const Call& call) { - ICHECK_EQ(call->args.size(), 1); + TVM_FFI_ICHECK_EQ(call->args.size(), 1); return Call(vm_kill_object_op_, {call->args[0]}, Attrs()); } Expr CallTIRDyn(const Call& call_node) { - ICHECK(call_node->args.size() == 2); - ICHECK(call_node->args[0]->IsInstance()); - ICHECK(call_node->args[1]->IsInstance()); + TVM_FFI_ICHECK(call_node->args.size() == 2); + TVM_FFI_ICHECK(call_node->args[0]->IsInstance()); + TVM_FFI_ICHECK(call_node->args[1]->IsInstance()); ffi::Array args; auto tir_args = Downcast(call_node->args[1]); @@ -121,12 +121,11 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { } Expr Reshape(const Call& call_node) { - ICHECK(call_node->args.size() == 2); - ICHECK(call_node->struct_info_.defined()); + TVM_FFI_ICHECK(call_node->args.size() == 2); + TVM_FFI_ICHECK(call_node->struct_info_.defined()); auto arg = call_node->args[1]; - CHECK(arg->struct_info_->IsInstance()) - << "TypeError: " + TVM_FFI_CHECK(arg->struct_info_->IsInstance(), TypeError) << "VMBuiltinLower expects the shape arg of R.reshape " << "to be a ShapeExpr or VarNode bound to a ShapeExpr. " << "However, in expression " << call_node << ", the shape argument " << arg @@ -136,21 +135,21 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { } Expr ShapeOf(const Call& call_node) { - ICHECK(call_node->args.size() == 1); - ICHECK(call_node->struct_info_.defined()); + TVM_FFI_ICHECK(call_node->args.size() == 1); + TVM_FFI_ICHECK(call_node->struct_info_.defined()); return Call(builtin_shape_of_, call_node->args, Attrs(), {GetStructInfo(call_node)}); } Expr TensorToShape(const Call& call_node) { - ICHECK(call_node->args.size() == 1); - ICHECK(call_node->struct_info_.defined()); + TVM_FFI_ICHECK(call_node->args.size() == 1); + TVM_FFI_ICHECK(call_node->struct_info_.defined()); return Call(builtin_tensor_to_shape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); } Expr CallPyFunc(const Call& call_node) { - ICHECK(call_node->args.size() == 2); - ICHECK(call_node->struct_info_.defined()); + TVM_FFI_ICHECK(call_node->args.size() == 2); + TVM_FFI_ICHECK(call_node->struct_info_.defined()); // Create tuple with function name and arguments tuple ffi::Array tuple_fields; @@ -165,8 +164,8 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { Expr ToDevice(const Call& call_node) { // TODO(yongwww): replace ToVDeviceAttrs with related Expr - ICHECK(call_node->args.size() == 1); - ICHECK(call_node->struct_info_.defined()); + TVM_FFI_ICHECK(call_node->args.size() == 1); + TVM_FFI_ICHECK(call_node->struct_info_.defined()); auto attrs = call_node->attrs.as(); ffi::Array args; args.push_back(call_node->args[0]); @@ -182,9 +181,9 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { } Expr MakeClosure(const Call& call_node) { - ICHECK(call_node->args.size() == 2); - ICHECK(call_node->args[0]->IsInstance()); - ICHECK(call_node->args[1]->IsInstance()); + TVM_FFI_ICHECK(call_node->args.size() == 2); + TVM_FFI_ICHECK(call_node->args[0]->IsInstance()); + TVM_FFI_ICHECK(call_node->args[1]->IsInstance()); ffi::Array args; auto func = call_node->args[0]; @@ -199,9 +198,9 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { } Expr InvokeClosure(const Call& call_node) { - ICHECK(call_node->args.size() == 2); - ICHECK(call_node->args[0]->IsInstance()); - ICHECK(call_node->args[1]->IsInstance()); + TVM_FFI_ICHECK(call_node->args.size() == 2); + TVM_FFI_ICHECK(call_node->args[0]->IsInstance()); + TVM_FFI_ICHECK(call_node->args[1]->IsInstance()); ffi::Array args; diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index bbc227d1d559..68c266c5dd11 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -308,7 +308,7 @@ class VMShapeLowerMutator ffi::Array dep_vars = tir::UndefinedVars(slot->expr); for (auto var : dep_vars) { auto it = slot_map_.find(var); - ICHECK(it != slot_map_.end()) + TVM_FFI_ICHECK(it != slot_map_.end()) << "Var " << var << "is not defined in the function but is referenced by " << slot->expr; auto* var_slot = it->second; @@ -348,8 +348,8 @@ class VMShapeLowerMutator // Expr mutation overloading. //------------------------------------------------------- Expr VisitExpr_(const FunctionNode* op) final { - LOG(FATAL) << "VMShapeLower do not work for local functions, make sure " - << " to run it after LambdaLift"; + TVM_FFI_THROW(InternalError) << "VMShapeLower do not work for local functions, make sure " + << " to run it after LambdaLift"; return ffi::GetRef(op); } @@ -361,9 +361,9 @@ class VMShapeLowerMutator PrimValue::Int64(int_expr->value)}; } else { auto it = slot_map_.find(expr); - ICHECK(it != slot_map_.end()); + TVM_FFI_ICHECK(it != slot_map_.end()); auto* slot = it->second; - ICHECK(slot->value_computed) + TVM_FFI_ICHECK(slot->value_computed) << "PrimExpr " << expr << " in function " << current_gvar_ << " has not been computed"; return {PrimValue::Int64(static_cast(MakeShapeCode::kLoadShape)), PrimValue::Int64(slot->index)}; @@ -455,14 +455,14 @@ class VMShapeLowerMutator } auto it = slot_map_.find(expr); - ICHECK(it != slot_map_.end()); + TVM_FFI_ICHECK(it != slot_map_.end()); auto* slot = it->second; if (slot->value_computed) { return {MatchShapeCode::kAssertEqualToLoad, PrimValue::Int64(slot->index)}; } // the value is not yet computed - ICHECK(!require_value_computed) << "PrimExpr " << expr << " is not computed"; + TVM_FFI_ICHECK(!require_value_computed) << "PrimExpr " << expr << " is not computed"; if (expr.as()) { // It is a var we will populate it in this round. @@ -508,7 +508,7 @@ class VMShapeLowerMutator Expr match_op; if (item.input->struct_info_.as()) { match_op = builtin_match_prim_value_; - ICHECK_EQ(item.pattern.size(), 1); + TVM_FFI_ICHECK_EQ(item.pattern.size(), 1); } else { match_op = builtin_match_shape_; args.push_back(PrimValue::Int64(item.pattern.size())); @@ -541,7 +541,7 @@ class VMShapeLowerMutator std::vector to_compute; for (PrimExprSlot* slot : ready_vars_) { for (PrimExprSlot* user : slot->user_slots) { - ICHECK_GT(user->outstanding_defs, 0); + TVM_FFI_ICHECK_GT(user->outstanding_defs, 0); user->outstanding_defs -= 1; if (user->outstanding_defs == 0) { to_compute.push_back(user); @@ -565,7 +565,7 @@ class VMShapeLowerMutator size_t EmitOutstandingPrimExprCompute() { std::vector to_compute = GetReadyPrimExprSlots(); if (to_compute.size() == 0) return 0; - ICHECK_GT(heap_size_->value, 0); + TVM_FFI_ICHECK_GT(heap_size_->value, 0); // construct a PrimFunc that compute the shape. tir::Var heap("heap", DataType::Handle()); ffi::Array buffer_shape{heap_size_}; @@ -575,13 +575,13 @@ class VMShapeLowerMutator auto var_map = [&](const tir::Var& var) -> ffi::Optional { auto it = slot_map_.find(var); - ICHECK(it != slot_map_.end()); + TVM_FFI_ICHECK(it != slot_map_.end()); return tir::BufferLoad(buffer, {IntImm(ShapeDType(), it->second->index)}); }; ffi::Array seq; for (PrimExprSlot* slot : to_compute) { - ICHECK(!slot->value_computed); + TVM_FFI_ICHECK(!slot->value_computed); slot->value_computed = true; PrimExpr value = tir::Substitute(slot->expr, var_map); seq.push_back(tir::BufferStore(buffer, value, {IntImm(ShapeDType(), slot->index)})); @@ -711,9 +711,10 @@ class VMShapeLowerMutator } else if (op->shape.as()) { // NOTE: This part of the logic is left empty for future support as it is less common. // Future implementors: we can emit a binding here and assert here. - LOG(FATAL) << "Cannot handle Tensor shape pattern where a var appears multiple times"; + TVM_FFI_THROW(InternalError) + << "Cannot handle Tensor shape pattern where a var appears multiple times"; } else { - ICHECK(!op->shape.defined()) << "Can only handle tensor shape pattern var"; + TVM_FFI_ICHECK(!op->shape.defined()) << "Can only handle tensor shape pattern var"; } } @@ -739,8 +740,8 @@ class VMShapeLowerMutator std::vector* match_todos) final { auto* value_tinfo = GetStructInfoAs(value); if (value_tinfo) { - CHECK_EQ(value_tinfo->fields.size(), op->fields.size()) - << "TypeError: " << err_ctx << " during match-cast we find tuple size mismatch"; + TVM_FFI_CHECK_EQ(value_tinfo->fields.size(), op->fields.size(), TypeError) + << err_ctx << " during match-cast we find tuple size mismatch"; } if (always_check || !value_tinfo) { // check_tuple_info(value, tuple_size) diff --git a/src/relax/distributed/axis_group_graph.cc b/src/relax/distributed/axis_group_graph.cc index 12feeacc8b0b..b53f68e44d4b 100644 --- a/src/relax/distributed/axis_group_graph.cc +++ b/src/relax/distributed/axis_group_graph.cc @@ -72,17 +72,17 @@ const TensorStructInfoNode* GetTensorStructInfo(Expr tensor) { if (dtensor_sinfo) { return dtensor_sinfo->tensor_sinfo.get(); } - LOG(FATAL) << tensor << " must be either Tensor or DTesor"; + TVM_FFI_THROW(InternalError) << tensor << " must be either Tensor or DTesor"; throw; } void UnaryOpHelper(ffi::Array tensor_list, distributed::AxisGroupGraph* axis_group_graph) { int n_dim = GetTensorStructInfo(tensor_list[0])->ndim; for (const auto& tensor : tensor_list) { - ICHECK(GetTensorStructInfo(tensor)->ndim == n_dim); + TVM_FFI_ICHECK(GetTensorStructInfo(tensor)->ndim == n_dim); } for (int i = 0; i < n_dim; i++) { - ICHECK(tensor_list.size() <= 2); + TVM_FFI_ICHECK(tensor_list.size() <= 2); for (int j = 0; j < static_cast(tensor_list.size()) - 1; j++) { axis_group_graph->JoinAxis({tensor_list[j].get(), i}, {tensor_list[j + 1].get(), i}, distributed::AxisGroupGraph::EdgeType::kDescend); @@ -122,7 +122,7 @@ void BuildAxisGraphBinary(const Var& output_var, const Call& call, int x2_ndim = x2_sinfo->ndim; const auto* x1_shape = x1_sinfo->shape.as(); const auto* x2_shape = x2_sinfo->shape.as(); - ICHECK(x1_shape && x2_shape); + TVM_FFI_ICHECK(x1_shape && x2_shape); arith::Analyzer analyzer; for (int i = 1; i <= std::min(x1_ndim, x2_ndim); ++i) { const PrimExpr& dim0 = x1_shape->values[x1_ndim - i]; @@ -144,7 +144,7 @@ void BuildAxisGraphBinary(const Var& output_var, const Call& call, {tensor_list[2].get(), std::max(x1_ndim, x2_ndim) - i}, distributed::AxisGroupGraph::EdgeType::kDescend); } else { - LOG(FATAL) << "Invalid broadcast, dim0: " << dim0 << ", dim1: " << dim1; + TVM_FFI_THROW(InternalError) << "Invalid broadcast, dim0: " << dim0 << ", dim1: " << dim1; } } if (x1_ndim > x2_ndim) { @@ -174,7 +174,7 @@ void BuildAxisGraphReduce(const Var& output_var, const Call& call, axes = {attrs->axis}; keepdims = true; } else { - LOG(FATAL) << "Unsupported reduce op: " << call->op; + TVM_FFI_THROW(InternalError) << "Unsupported reduce op: " << call->op; } int ndim = GetTensorStructInfo(input_tensor)->ndim; @@ -182,7 +182,7 @@ void BuildAxisGraphReduce(const Var& output_var, const Call& call, std::unordered_set normalized_axes; for (const Integer& i : axes) { int val = i->value; - ICHECK(val < ndim && val >= -ndim); + TVM_FFI_ICHECK(val < ndim && val >= -ndim); if (val < 0) { val = ndim + val; } @@ -215,7 +215,7 @@ void BuildAxisGraphMatmul(const Var& output_var, const Call& call, const auto* x2_sinfo = GetTensorStructInfo(x2); int x1_ndim = x1_sinfo->ndim; int x2_ndim = x2_sinfo->ndim; - ICHECK(x1_ndim > 0 && x2_ndim > 0); + TVM_FFI_ICHECK(x1_ndim > 0 && x2_ndim > 0); int x1_prepended = 0; int x2_appended = 0; if (x1_ndim == 1) { @@ -228,7 +228,7 @@ void BuildAxisGraphMatmul(const Var& output_var, const Call& call, } const auto* x1_shape = x1_sinfo->shape.as(); const auto* x2_shape = x2_sinfo->shape.as(); - ICHECK(x1_shape && x2_shape); + TVM_FFI_ICHECK(x1_shape && x2_shape); ffi::Array x1_shape_prefix{x1_shape->values.begin(), x1_shape->values.end() - 2 + x1_prepended}; ffi::Array x2_shape_prefix{x2_shape->values.begin(), @@ -257,7 +257,7 @@ void BuildAxisGraphMatmul(const Var& output_var, const Call& call, {x3.get(), std::max(x1_prefix_ndim, x2_prefix_ndim) - i}, distributed::AxisGroupGraph::EdgeType::kDescend); } else { - LOG(FATAL) << "Cannot broadcast " << dim0 << " and " << dim1; + TVM_FFI_THROW(InternalError) << "Cannot broadcast " << dim0 << " and " << dim1; } } // join reduction dim @@ -284,13 +284,13 @@ void BuildAxisGraphPermuteDims(const Var& output_var, const Call& call, distributed::AxisGroupGraph* axis_group_graph) { Expr input_tensor = call->args[0]; const auto* attrs = call->attrs.as(); - ICHECK(attrs); + TVM_FFI_ICHECK(attrs); int ndim = GetTensorStructInfo(input_tensor)->ndim; std::vector normalized_axes; if (attrs->axes.defined()) { for (const Integer& i : attrs->axes.value()) { int val = i->value; - ICHECK(val < ndim && val >= -ndim); + TVM_FFI_ICHECK(val < ndim && val >= -ndim); if (val < 0) { val = ndim + val; } @@ -311,7 +311,7 @@ void BuildAxisGraphReshape(const Var& output_var, const Call& call, const auto* tensor_sinfo = GetTensorStructInfo(input_tensor); const auto* new_shape_sinfo = GetStructInfoAs(call->args[1]); const auto* old_shape_sinfo = GetStructInfoAs(tensor_sinfo->shape.value()); - ICHECK_NOTNULL(old_shape_sinfo); + TVM_FFI_ICHECK_NOTNULL(old_shape_sinfo); ffi::Array old_shape_values = old_shape_sinfo->values.value(); ffi::Array new_shape_values = new_shape_sinfo->values.value(); int i = old_shape_values.size(); diff --git a/src/relax/distributed/global_info.cc b/src/relax/distributed/global_info.cc index 408d31680c79..17a3bd7f3fef 100644 --- a/src/relax/distributed/global_info.cc +++ b/src/relax/distributed/global_info.cc @@ -32,7 +32,7 @@ DeviceMesh::DeviceMesh(ffi::Shape shape, ffi::Array device_ids) { prod *= shape[i]; } ObjectPtr n = ffi::make_object(); - CHECK_EQ(prod, static_cast(device_ids.size())) + TVM_FFI_ICHECK_EQ(prod, static_cast(device_ids.size())) << "The number of device ids must match the product of the shape"; n->shape = std::move(shape); n->device_ids = std::move(device_ids); @@ -51,7 +51,7 @@ DeviceMesh::DeviceMesh(ffi::Shape shape, Range device_range) { for (int i = 0; i < static_cast(shape.size()); i++) { prod *= shape[i]; } - CHECK_EQ(prod, static_cast(device_ids.size())) + TVM_FFI_ICHECK_EQ(prod, static_cast(device_ids.size())) << "The number of device ids must match the product of the shape"; n->device_ids = std::move(device_ids); n->shape = std::move(shape); diff --git a/src/relax/distributed/struct_info.cc b/src/relax/distributed/struct_info.cc index 5c51920fa7e6..731564825015 100644 --- a/src/relax/distributed/struct_info.cc +++ b/src/relax/distributed/struct_info.cc @@ -90,20 +90,20 @@ Placement Placement::FromText(ffi::String text_repr) { } else if (indicator == 'S') { char lbracket; ss >> lbracket; - CHECK_EQ(lbracket, '['); + TVM_FFI_ICHECK_EQ(lbracket, '['); std::string substr; getline(ss, substr, ']'); std::stringstream ss2(substr); int dim; ss2 >> dim; dim_specs.push_back(PlacementSpec::Sharding(dim)); - CHECK(ss2.eof()) << "Invalid placement text repr"; + TVM_FFI_ICHECK(ss2.eof()) << "Invalid placement text repr"; } else if (indicator == ',') { continue; } else if (indicator == ' ') { continue; } else { - LOG(FATAL) << "Invalid placement text repr"; + TVM_FFI_THROW(InternalError) << "Invalid placement text repr"; } } return Placement(dim_specs); @@ -120,12 +120,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { // DTensor DTensorStructInfo::DTensorStructInfo(TensorStructInfo tensor_sinfo, DeviceMesh device_mesh, Placement placement, Span span) { - CHECK_EQ(device_mesh->shape.size(), placement->dim_specs.size()) - << "ValueError: The device mesh and placement must have the same dimension size"; + TVM_FFI_CHECK_EQ(device_mesh->shape.size(), placement->dim_specs.size(), ValueError) + << "The device mesh and placement must have the same dimension size"; for (auto spec : placement->dim_specs) { if (spec->kind == PlacementSpecKind::kReplica) continue; - CHECK_LT(spec->axis, tensor_sinfo->ndim) - << "ValueError: Sharding dimension should be smaller than tensor ndim"; + TVM_FFI_CHECK_LT(spec->axis, tensor_sinfo->ndim, ValueError) + << "Sharding dimension should be smaller than tensor ndim"; } ObjectPtr n = ffi::make_object(); n->device_mesh = std::move(device_mesh); diff --git a/src/relax/distributed/transform/legalize_redistribute.cc b/src/relax/distributed/transform/legalize_redistribute.cc index e9ff5e72f2b9..7faa4697874d 100644 --- a/src/relax/distributed/transform/legalize_redistribute.cc +++ b/src/relax/distributed/transform/legalize_redistribute.cc @@ -67,15 +67,15 @@ class RedistributeLegalizer : public ExprMutator { static Op redistribute_op = Op::Get("relax.dist.redistribute"); if (call->op.same_as(redistribute_op)) { const auto* attrs = call->attrs.as(); - ICHECK(attrs); + TVM_FFI_ICHECK(attrs); const auto* input_sinfo = call->args[0]->struct_info_.as(); - ICHECK(input_sinfo); + TVM_FFI_ICHECK(input_sinfo); // As the first step, we only support redistribute in the same device mesh, // and the device mesh must be 1d // todo: extend the ccl ops so that it can support 2d device mesh, and different sharding // dimension - ICHECK(StructuralEqual()(input_sinfo->device_mesh, attrs->device_mesh)); - ICHECK(input_sinfo->device_mesh->shape.size() == 1); + TVM_FFI_ICHECK(StructuralEqual()(input_sinfo->device_mesh, attrs->device_mesh)); + TVM_FFI_ICHECK(input_sinfo->device_mesh->shape.size() == 1); // only support "S[x]"-> "R" and "R" -> "S[x]" PlacementSpec input_spec = input_sinfo->placement->dim_specs[0]; PlacementSpec output_spec = attrs->placement->dim_specs[0]; @@ -87,21 +87,21 @@ class RedistributeLegalizer : public ExprMutator { output_spec->kind == PlacementSpecKind::kSharding) { // "S[x]" -> "S[y]" if (input_spec->axis != output_spec->axis) { - LOG(FATAL) << "AlltoAll not implemented yet"; + TVM_FFI_THROW(InternalError) << "AlltoAll not implemented yet"; } else { return call->args[0]; } } else if (input_spec->kind == PlacementSpecKind::kSharding && output_spec->kind == PlacementSpecKind::kReplica) { // "S[x]" -> "R" - LOG(FATAL) << "Allgather not implemented yet"; + TVM_FFI_THROW(InternalError) << "Allgather not implemented yet"; } else if (input_spec->kind == PlacementSpecKind::kReplica && output_spec->kind == PlacementSpecKind::kSharding) { // "R" -> "S[x]" return redistribute_replica_to_shard(call->args[0], attrs->device_mesh->shape[0], output_spec->axis); } else { - LOG(FATAL) << "Unsupported redistribute op"; + TVM_FFI_THROW(InternalError) << "Unsupported redistribute op"; } } return call; diff --git a/src/relax/distributed/transform/lower_distir.cc b/src/relax/distributed/transform/lower_distir.cc index 676fce094a5b..83300f80acb2 100644 --- a/src/relax/distributed/transform/lower_distir.cc +++ b/src/relax/distributed/transform/lower_distir.cc @@ -76,7 +76,7 @@ class DistIRSharder : public ExprMutator { TensorStructInfo ShardDTensorSinfo(DTensorStructInfo orig_sinfo) { TensorStructInfo tensor_sinfo = orig_sinfo->tensor_sinfo; - ICHECK(tensor_sinfo->shape); + TVM_FFI_ICHECK(tensor_sinfo->shape); const auto* orig_shape = tensor_sinfo->shape.as(); auto new_tensor_sinfo = ffi::make_object(*tensor_sinfo.get()); new_tensor_sinfo->shape = ShardShape(ffi::GetRef(orig_shape), @@ -111,7 +111,7 @@ class DistIRSharder : public ExprMutator { } Expr ShardInputParamTensorAndConstant(Expr input) { - ICHECK(input->struct_info_); + TVM_FFI_ICHECK(input->struct_info_); StructInfo old_sinfo = GetStructInfo(input); StructInfo new_sinfo = ConvertSinfo(old_sinfo, false); if (const auto* var = input.as()) { @@ -119,19 +119,19 @@ class DistIRSharder : public ExprMutator { return new_param; } else if (const auto* constant = input.as()) { for (const auto& spec : Downcast(old_sinfo)->placement->dim_specs) { - ICHECK(spec->kind == PlacementSpecKind::kReplica); + TVM_FFI_ICHECK(spec->kind == PlacementSpecKind::kReplica); } Constant new_constant(constant->data, new_sinfo); return new_constant; } else { - LOG(FATAL) << "Cannot shard tensor which is not Var or Constant: " << input; + TVM_FFI_THROW(InternalError) << "Cannot shard tensor which is not Var or Constant: " << input; throw; } } void EmitBroadcastOrScatter(Expr old_expr, Expr new_expr, DTensorStructInfo dtensor_sinfo) { // FIXME: this is a hack that only works for 1d device mesh - ICHECK(dtensor_sinfo->device_mesh->shape.size() == 1); + TVM_FFI_ICHECK(dtensor_sinfo->device_mesh->shape.size() == 1); PlacementSpec sharding_spec = dtensor_sinfo->placement->dim_specs[0]; if (sharding_spec->kind == PlacementSpecKind::kReplica) { Var new_var = builder_->Emit(broadcast_from_worker0(new_expr)); @@ -149,7 +149,7 @@ class DistIRSharder : public ExprMutator { tuple_getitem_remap_[Downcast(old_expr)] = scatter_var; } } else { - LOG(FATAL) << "Unsupported placement spec"; + TVM_FFI_THROW(InternalError) << "Unsupported placement spec"; } } @@ -215,9 +215,9 @@ class DistIRSharder : public ExprMutator { static Op call_tir_op = Op::Get("relax.call_tir"); static Op call_tir_local_view_op = Op::Get("relax.dist.call_tir_local_view"); if (call->op.same_as(reshape_op)) { - ICHECK(call->args[1].as()); + TVM_FFI_ICHECK(call->args[1].as()); const auto* out_sinfo = GetStructInfoAs(binding_var); - ICHECK(out_sinfo); + TVM_FFI_ICHECK(out_sinfo); auto new_call_node = ffi::make_object(*call); new_call_node->args.Set(1, ShardShape(Downcast(call->args[1]), out_sinfo->device_mesh, out_sinfo->placement)); @@ -228,7 +228,8 @@ class DistIRSharder : public ExprMutator { new_call_node->sinfo_args = {ConvertSinfo(GetStructInfo(binding_var), true)}; return Call(new_call_node); } else if (call->op.same_as(call_tir_op)) { - LOG(FATAL) << "call_tir should be lowered to call_tir_local_view before lowering to relax"; + TVM_FFI_THROW(InternalError) + << "call_tir should be lowered to call_tir_local_view before lowering to relax"; } else if (const auto* extern_func = call->op.as()) { auto new_call_node = ffi::make_object(*call); if (extern_func->global_symbol == "vm.builtin.distributed.attention_kv_cache_append") { @@ -237,7 +238,7 @@ class DistIRSharder : public ExprMutator { new_call_node->op = ExternFunc("vm.builtin.attention_kv_cache_view"); auto orig_shape = Downcast(call->args[1]); const auto* out_sinfo = GetStructInfoAs(binding_var); - ICHECK(out_sinfo); + TVM_FFI_ICHECK(out_sinfo); ShapeExpr new_shape = ShardShape(orig_shape, out_sinfo->device_mesh, out_sinfo->placement); new_call_node->args.Set(1, new_shape); new_call_node->sinfo_args = {TensorStructInfo(new_shape, out_sinfo->tensor_sinfo->dtype)}; diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc b/src/relax/distributed/transform/lower_global_view_to_local_view.cc index 7dbbbf3a9566..66cd520cebb0 100644 --- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc +++ b/src/relax/distributed/transform/lower_global_view_to_local_view.cc @@ -93,7 +93,7 @@ class DistSBlockInfoCollector : public StmtExprVisitor { void VisitStmt_(const SBlockNode* op) final { for (const auto& iter_var : op->iter_vars) { if (iter_var->iter_type == kCommReduce) { - ICHECK(op->writes.size() == 1); + TVM_FFI_ICHECK(op->writes.size() == 1); reduce_buffer_ = op->writes[0]->buffer; } } @@ -229,7 +229,7 @@ class DistributedBufferCompactor : StmtExprMutator { int dim = pr.first; int shard = pr.second; Var var = GetShardingVarFromIndex(access_index[dim], iter_var_range, &analyzer); - ICHECK(!iter_var_shards_.count(var) || iter_var_shards_[var] == shard) + TVM_FFI_ICHECK(!iter_var_shards_.count(var) || iter_var_shards_[var] == shard) << "A loop cannot have different sharding"; iter_var_shards_[var] = shard; } @@ -242,9 +242,9 @@ class DistributedBufferCompactor : StmtExprMutator { int shard = iter_var_shards_[iter_var->var]; if (shard > 1) { Range dom = iter_var->dom; - ICHECK(is_zero(dom->min)); + TVM_FFI_ICHECK(is_zero(dom->min)); arith::Analyzer analyzer; - ICHECK(analyzer.CanProve(floormod(dom->extent, shard) == 0)); + TVM_FFI_ICHECK(analyzer.CanProve(floormod(dom->extent, shard) == 0)); new_iter_vars.push_back( IterVar(Range::FromMinExtent(dom->min, floordiv(dom->extent, shard)), iter_var->var, iter_var->iter_type, iter_var->thread_tag)); @@ -292,7 +292,7 @@ class DistributedBufferCompactor : StmtExprMutator { // sharding on reduction axis for (const IterVar& iter_var : new_iter_vars) { if (iter_var->iter_type == kCommReduce && iter_var_shards_.count(iter_var->var)) { - ICHECK(add_allreduce_kind_ == ""); + TVM_FFI_ICHECK(add_allreduce_kind_ == ""); AddAllReduceBlock(collector.reduce_kind); break; } @@ -320,7 +320,7 @@ class DistributedBufferCompactor : StmtExprMutator { if (!iter_var_shards_.count(iter_var->var)) { continue; } - ICHECK(iter_value.as()); + TVM_FFI_ICHECK(iter_value.as()); loop_var_shards_[Downcast(iter_value)] = iter_var_shards_[iter_var->var]; } return realize; @@ -332,7 +332,7 @@ class DistributedBufferCompactor : StmtExprMutator { int shard = loop_var_shards_[op->loop_var]; if (shard > 1) { arith::Analyzer analyzer; - ICHECK(analyzer.CanProve(floormod(new_loop->extent, shard) == 0)); + TVM_FFI_ICHECK(analyzer.CanProve(floormod(new_loop->extent, shard) == 0)); new_loop.CopyOnWrite()->extent = floordiv(new_loop->extent, shard); return new_loop; } @@ -386,7 +386,8 @@ class LowerTIRToLocalView : public ExprMutator { } return ret; } else { - LOG(FATAL) << "The output of a call_tir should be a DTensorStructInfo or TupleStructInfo"; + TVM_FFI_THROW(InternalError) + << "The output of a call_tir should be a DTensorStructInfo or TupleStructInfo"; } } @@ -400,7 +401,7 @@ class LowerTIRToLocalView : public ExprMutator { ffi::Array args = Downcast(val->args[1])->fields; for (const auto& arg : args) { const auto* sinfo = GetStructInfoAs(arg); - ICHECK(sinfo); + TVM_FFI_ICHECK(sinfo); sharding_specs.push_back(ShardingSpec(sinfo->device_mesh, sinfo->placement)); } Var output_var = binding->var; diff --git a/src/relax/distributed/transform/propagate_sharding.cc b/src/relax/distributed/transform/propagate_sharding.cc index 1ff614c019c8..703857da9141 100644 --- a/src/relax/distributed/transform/propagate_sharding.cc +++ b/src/relax/distributed/transform/propagate_sharding.cc @@ -188,7 +188,7 @@ class AxisGroupGraphBuilder : public ExprVisitor { if (const auto* tensor_sinfo = binding->var->struct_info_.as()) { tensor_sinfos.push_back(ffi::GetRef(tensor_sinfo)); } else if (const auto* tuple_sinfo = binding->var->struct_info_.as()) { - ICHECK(tuple_sinfo); + TVM_FFI_ICHECK(tuple_sinfo); for (const auto& sinfo : tuple_sinfo->fields) { tensor_sinfos.push_back(Downcast(sinfo)); } @@ -227,7 +227,7 @@ class ShardingAnnotationCollector : public ExprVisitor { static const Op& annotate_sharding_op = Op::Get("relax.dist.annotate_sharding"); if (val->op.same_as(annotate_sharding_op)) { const auto* attrs = val->attrs.as(); - ICHECK(attrs); + TVM_FFI_ICHECK(attrs); for (int i = 0; i < static_cast(attrs->placement->dim_specs.size()); i++) { const PlacementSpec& placement_spec = attrs->placement->dim_specs[i]; @@ -267,9 +267,9 @@ class ShardingConflictHandler : public ExprVisitor { void CheckTensorShardingCompatible(Var var) { const auto* sinfo = GetStructInfoAs(var); - ICHECK(sinfo); + TVM_FFI_ICHECK(sinfo); const auto* shape = sinfo->shape.as(); - ICHECK(shape); + TVM_FFI_ICHECK(shape); int ndim = sinfo->ndim; std::unordered_set sharded_mesh_dim; ffi::Optional device_mesh; @@ -283,7 +283,7 @@ class ShardingConflictHandler : public ExprVisitor { } if (device_mesh.defined()) { - ICHECK(StructuralEqual()(device_mesh.value(), sharding_spec.first)) + TVM_FFI_ICHECK(StructuralEqual()(device_mesh.value(), sharding_spec.first)) << "Sharding conflict detected for tensor " << var->name_hint() << ": Device Mesh mismatch" << ". Conflict Handling logic will be added in the future."; @@ -292,7 +292,7 @@ class ShardingConflictHandler : public ExprVisitor { } if (i >= 0) { int sharding_dim = sharding_spec.second; - ICHECK(sharded_mesh_dim.count(sharding_dim) == 0) + TVM_FFI_ICHECK(sharded_mesh_dim.count(sharding_dim) == 0) << "Sharding conflict detected for tensor " << var->name_hint() << ": Replicate sharding device mesh axis " << sharding_dim << ". Conflict Handling logic will be added in the future."; @@ -313,7 +313,7 @@ class ShardingConflictHandler : public ExprVisitor { int has_sharding_spec; std::tie(sharding_spec, has_sharding_spec) = axis_group_graph_->GetAxisShardingSpec({constant.get(), i}); - ICHECK(!has_sharding_spec) + TVM_FFI_ICHECK(!has_sharding_spec) << "Constant is not allowed to be sharded. Please convert it into an input param."; } } @@ -366,7 +366,8 @@ class DistributedIRBuilder : public ExprMutator { int ndim = sinfo->ndim; DeviceMesh device_mesh = std::get<0>(axis_group_graph_.GetAxisShardingSpec({expr.get(), -1, tuple_idx})).first; - ICHECK(device_mesh.defined()) << expr << "[" << tuple_idx << "] is not assigned device mesh"; + TVM_FFI_ICHECK(device_mesh.defined()) + << expr << "[" << tuple_idx << "] is not assigned device mesh"; ffi::Array placement_specs( std::vector(device_mesh->shape.size(), PlacementSpec::Replica())); for (int i = 0; i < ndim; i++) { @@ -407,7 +408,7 @@ class DistributedIRBuilder : public ExprMutator { Constant new_constant(constant->data, new_sinfo); return new_constant; } else { - LOG(FATAL) << "Cannot rewrite tensor which is not a Var or Constant"; + TVM_FFI_THROW(InternalError) << "Cannot rewrite tensor which is not a Var or Constant"; throw; } } @@ -440,7 +441,7 @@ class DistributedIRBuilder : public ExprMutator { FBuildAxisGraph f = [&](const Var& var, const Call& call, AxisGroupGraph* axis_group_graph) { ffi::Optional prim_func = MatchPrimFunc(this->builder_->GetContextIRModule(), call->args[0]); - ICHECK(prim_func); + TVM_FFI_ICHECK(prim_func); return BuildAxisGraphCallTIR(var, call, prim_func.value(), axis_group_graph); }; Call new_call = Downcast(ExprMutator::VisitExpr_(call)); @@ -499,7 +500,7 @@ class DistributedIRBuilder : public ExprMutator { new_call->struct_info_ = new_dtensor_sinfo; } } else if (call->op.same_as(call_tir_op)) { - ICHECK(call->sinfo_args.size() == 1); + TVM_FFI_ICHECK(call->sinfo_args.size() == 1); if (!SinfoCompatibleWithDistIR(call->sinfo_args)) { ObjectPtr new_call_node = ffi::make_object(*call.get()); if (placements.size() == 1) { @@ -507,7 +508,7 @@ class DistributedIRBuilder : public ExprMutator { Downcast(call->sinfo_args[0]), device_mesh, placements[0])}; } else { const auto* tuple_sinfo = call->sinfo_args[0].as(); - ICHECK(placements.size() == tuple_sinfo->fields.size()); + TVM_FFI_ICHECK(placements.size() == tuple_sinfo->fields.size()); ffi::Array new_tuple_sinfo_fields; for (int i = 0; i < static_cast(placements.size()); i++) { new_tuple_sinfo_fields.push_back(DTensorStructInfo( @@ -537,7 +538,7 @@ class DistributedIRBuilder : public ExprMutator { // get annotated sinfo from axis group graph DeviceMesh device_mesh = std::get<0>(axis_group_graph_.GetAxisShardingSpec({binding->var.get(), -1})).first; - ICHECK(device_mesh.defined()); + TVM_FFI_ICHECK(device_mesh.defined()); ffi::Array placements; // every tuple element has a placement for (int idx = 0; idx < static_cast(orig_output_tensor_sinfos.size()); idx++) { ffi::Array placement_specs( @@ -572,7 +573,7 @@ class DistributedIRBuilder : public ExprMutator { } } else { const auto* inferred_tuple_sinfo = new_call->struct_info_.as(); - ICHECK(inferred_tuple_sinfo) << new_call; + TVM_FFI_ICHECK(inferred_tuple_sinfo) << new_call; Var new_var = builder_->Emit(new_call); var_remap_[binding->var->vid] = new_var; for (int i = 0; i < static_cast(inferred_tuple_sinfo->fields.size()); i++) { diff --git a/src/relax/distributed/transform/utils.cc b/src/relax/distributed/transform/utils.cc index 0bcd730d42c8..fb86bc5cae4a 100644 --- a/src/relax/distributed/transform/utils.cc +++ b/src/relax/distributed/transform/utils.cc @@ -48,7 +48,7 @@ bool SinfoCompatibleWithRelax(ffi::Array sinfos) { bool IsDistIRFunc(Function func) { ffi::Array param_sinfos; for (const auto& param : func->params) { - ICHECK(param->struct_info_); + TVM_FFI_ICHECK(param->struct_info_); param_sinfos.push_back(Downcast(param->struct_info_.value())); } bool compatible_with_dist_ir = SinfoCompatibleWithDistIR(param_sinfos); @@ -58,7 +58,7 @@ bool IsDistIRFunc(Function func) { } else if (compatible_with_dist_ir && !compatible_with_relax) { return true; } else { - LOG(FATAL) << "mixed use of DTensor and Tensor in: " << func; + TVM_FFI_THROW(InternalError) << "mixed use of DTensor and Tensor in: " << func; } } diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc index a8dcf78155dd..20bb674da58a 100644 --- a/src/relax/ir/binding_rewrite.cc +++ b/src/relax/ir/binding_rewrite.cc @@ -84,8 +84,8 @@ void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) { } }; - ICHECK(to_users_.find(old_var) != to_users_.end()) << "Cannot find " << old_var; - ICHECK(to_users_.find(new_var) != to_users_.end()) << "Cannot find " << new_var; + TVM_FFI_ICHECK(to_users_.find(old_var) != to_users_.end()) << "Cannot find " << old_var; + TVM_FFI_ICHECK(to_users_.find(new_var) != to_users_.end()) << "Cannot find " << new_var; // replace uses inside the DataflowBlock. ReplaceAllUsePass replacer(old_var, new_var, dfb_.get()); @@ -142,11 +142,11 @@ void DataflowBlockRewriteNode::Add(Binding binding) { } else if (auto mc = binding.as()) { return std::make_pair(mc->var, mc->value); } - LOG(FATAL) << "Unsupported binding type"; + TVM_FFI_THROW(InternalError) << "Unsupported binding type"; return std::make_pair(Var{}, Expr{}); }(); - ICHECK(0 == to_users_.count(var)) << var << " has been defined so cannot be added."; + TVM_FFI_ICHECK(0 == to_users_.count(var)) << var << " has been defined so cannot be added."; // Add this VarBinding statement after the definition of uses. auto used_vars = GetUsedVars(val); @@ -216,7 +216,7 @@ std::set GetUnusedVars(ffi::Map> users_map, ffi::Array users_map.erase(unused[i]); // remove def site. for (const auto& used_var : used) { - ICHECK(users_map.count(used_var)); + TVM_FFI_ICHECK(users_map.count(used_var)); ffi::Array var_users = users_map[used_var]; // remove the unused var from the use site. if (auto it = std::find(var_users.begin(), var_users.end(), unused[i]); @@ -272,10 +272,11 @@ void DataflowBlockRewriteNode::RemoveUnused(Var unused, bool allow_undef) { // first need to check if this var is used. if (to_users_.count(unused) == 0) { // no def. if (allow_undef) return; - LOG(FATAL) << unused << " undefined. Set allow_undef=True to allow 'removing' undefined var"; + TVM_FFI_THROW(InternalError) + << unused << " undefined. Set allow_undef=True to allow 'removing' undefined var"; } - ICHECK(to_users_[unused].empty()) + TVM_FFI_ICHECK(to_users_[unused].empty()) << unused << " is used by " << to_users_[unused].size() << " vars"; auto old_dfb = dfb_; diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 56146e80d063..057351e3d069 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -104,9 +104,9 @@ class BlockBuilderImpl : public BlockBuilderNode { (*ctx_func_dedup_map_)[func].insert(gvar); return gvar; } else { - ICHECK(it->second.size()) << "Values contained in de-duplication map must be non-empty sets, " - << "but found an empty set for function of type " - << func->GetTypeKey(); + TVM_FFI_ICHECK(it->second.size()) + << "Values contained in de-duplication map must be non-empty sets, " + << "but found an empty set for function of type " << func->GetTypeKey(); // To provide deterministic results, return the GlobalVar that // comes first in lexicographic order. return *std::min_element( @@ -124,11 +124,11 @@ class BlockBuilderImpl : public BlockBuilderNode { if (it != context_mod_->functions.end()) { BaseFunc old_func = (*it).second; auto ptr = ctx_func_dedup_map_->find(old_func); - ICHECK(ptr != ctx_func_dedup_map_->end()) + TVM_FFI_ICHECK(ptr != ctx_func_dedup_map_->end()) << "BlockBuilder::UpdateFunction is updating " << gv << ", which appears in the BlockBuilder's context_mod_, " << "but does not appear in the de-duplication map"; - ICHECK(ptr->second.count(gv)) + TVM_FFI_ICHECK(ptr->second.count(gv)) << "BlockBuilder::UpdateFunction is updating " << gv << ", but the de-duplication map for the previous value of this function " << "does not include " << gv; @@ -154,7 +154,8 @@ class BlockBuilderImpl : public BlockBuilderNode { // the change IRModule in COW. Additionally, we need to be able to // continue use the builder after an error is thrown to avoid state building up. // in an interactive environment. - LOG(FATAL) << diagnostic->message; + throw ffi::Error(diagnostic->error_kind, diagnostic->message, + TVMFFIBacktrace(__FILE__, __LINE__, "", 0)); } //------------------------------- @@ -219,8 +220,8 @@ class BlockBuilderImpl : public BlockBuilderNode { analyzer_.MarkGlobalNonNegValue(shape_var); } else { const PrimExpr& old_shape_expr = (*it).second; - CHECK(old_shape_expr.same_as(shape_expr) || - analyzer_.CanProveEqual(old_shape_expr, shape_expr)) + TVM_FFI_ICHECK(old_shape_expr.same_as(shape_expr) || + analyzer_.CanProveEqual(old_shape_expr, shape_expr)) << "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs " << shape_expr; } @@ -246,7 +247,8 @@ class BlockBuilderImpl : public BlockBuilderNode { Var EmitMatchCast(Expr value, StructInfo struct_info, ffi::String name_hint) final { value = this->Normalize(value); - CHECK(StructInfoBaseCheck(GetStructInfo(value), struct_info) != BaseCheckResult::kFailL0) + TVM_FFI_ICHECK(StructInfoBaseCheck(GetStructInfo(value), struct_info) != + BaseCheckResult::kFailL0) << "It is impossible to match cast any value into the target struct_info. " "But got value struct info: " << GetStructInfo(value) << ", given struct info: " << struct_info; @@ -268,7 +270,7 @@ class BlockBuilderImpl : public BlockBuilderNode { Var EmitOutput(Expr output, ffi::String name_hint) final { BindingBlockFrame* cur_frame = CurrentBindingBlockFrame(); - ICHECK(cur_frame->is_dataflow) << "EmitOutput has to be called inside dataflow block."; + TVM_FFI_ICHECK(cur_frame->is_dataflow) << "EmitOutput has to be called inside dataflow block."; return Emit(output, false, name_hint); } @@ -278,28 +280,28 @@ class BlockBuilderImpl : public BlockBuilderNode { if (const auto* var_binding = binding.as()) { if (!cur_frame->is_dataflow) { - ICHECK(!var_binding->var.as()) + TVM_FFI_ICHECK(!var_binding->var.as()) << "Cannot emit dataflow var in non-dataflow block"; } // normalized check - ICHECK(var_binding->var->struct_info_.defined()); - ICHECK(var_binding->value->struct_info_.defined()); + TVM_FFI_ICHECK(var_binding->var->struct_info_.defined()); + TVM_FFI_ICHECK(var_binding->value->struct_info_.defined()); cur_frame->bindings.push_back(binding); binding_table_[var_binding->var->vid] = var_binding->value; } else if (const auto* match_cast = binding.as()) { if (!cur_frame->is_dataflow) { - ICHECK(!match_cast->var.as()) + TVM_FFI_ICHECK(!match_cast->var.as()) << "Cannot emit dataflow var in non-dataflow block"; } // normalized check - ICHECK(match_cast->var->struct_info_.defined()); - ICHECK(match_cast->value->struct_info_.defined()); + TVM_FFI_ICHECK(match_cast->var->struct_info_.defined()); + TVM_FFI_ICHECK(match_cast->value->struct_info_.defined()); // NOTE match shape do not follow simple binding rule // as a result should not appear in binding table. cur_frame->bindings.push_back(binding); AddDefinitionToScope(match_cast->var); } else { - LOG(FATAL) << "Unsupported binding type: " << binding->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unsupported binding type: " << binding->GetTypeKey(); } } @@ -369,7 +371,7 @@ class BlockBuilderImpl : public BlockBuilderNode { * then the block frame is no longer valid. */ BindingBlockFrame* CurrentBindingBlockFrame() { - ICHECK(!block_stack_.empty()) << "no block is being built"; + TVM_FFI_ICHECK(!block_stack_.empty()) << "no block is being built"; return &block_stack_.back(); } @@ -378,7 +380,7 @@ class BlockBuilderImpl : public BlockBuilderNode { * \note only use this value */ ScopeFrame* CurrentScopeFrame() { - ICHECK(!scope_stack_.empty()) << "no scope is being opened"; + TVM_FFI_ICHECK(!scope_stack_.empty()) << "no scope is being opened"; return &scope_stack_.back(); } @@ -534,7 +536,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorIsInstance()) { - ICHECK(normalized->struct_info_.defined()) + TVM_FFI_ICHECK(normalized->struct_info_.defined()) << "The struct_info_ of an Expr except OpNode after " "normalization must not be nullptr. However, this Expr does not have struct_info_: " << normalized; @@ -563,7 +565,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorEmit(post, ""); // NOTE: current frame addr can change due to underlying vector // re-allocation, redo lookup @@ -587,7 +589,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorstruct_info_.defined()) + TVM_FFI_ICHECK(var->struct_info_.defined()) << "Var " << var->name_hint() << " does not have struct info."; return ffi::GetRef(var); } @@ -757,9 +759,9 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorstruct_info_.defined()) { auto opt = MatchStructInfo(node->tuple); - ICHECK(opt) << "The struct info of Tuple must be TupleStructInfo, " - << "but expression " << node->tuple << " has struct info " - << node->tuple->struct_info_; + TVM_FFI_ICHECK(opt) << "The struct info of Tuple must be TupleStructInfo, " + << "but expression " << node->tuple << " has struct info " + << node->tuple->struct_info_; UpdateStructInfo(node, opt.value()->fields[node->index]); } @@ -771,7 +773,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorVisitVarBinding(ffi::GetRef(var_binding)); } else { auto* match_cast = binding.as(); - ICHECK(match_cast) << "Unsupported binding type: " << binding->GetTypeKey(); + TVM_FFI_ICHECK(match_cast) << "Unsupported binding type: " << binding->GetTypeKey(); return this->VisitMatchCast(ffi::GetRef(match_cast)); } } @@ -835,21 +837,21 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorargs) { - ICHECK(!arg->struct_info_.as()) + TVM_FFI_ICHECK(!arg->struct_info_.as()) << "Distributed operator must take DTensor instead of Tensor as input"; } - ICHECK(op_map_dist_infer_struct_info_.count(op)) + TVM_FFI_ICHECK(op_map_dist_infer_struct_info_.count(op)) << " Cannot find the dist.FInferStructInfo attribute registered to op: " << op->name; return op_map_dist_infer_struct_info_[op](call, ffi::GetRef(this)); } - ICHECK(op_map_infer_struct_info_.count(op)) + TVM_FFI_ICHECK(op_map_infer_struct_info_.count(op)) << " Cannot find the FInferStructInfo attribute registered to op: " << op->name; return op_map_infer_struct_info_[op](call, ffi::GetRef(this)); } else { // derive using function parameters - ICHECK(call->op->struct_info_.defined()); + TVM_FFI_ICHECK(call->op->struct_info_.defined()); auto opt = MatchStructInfo(call->op); - ICHECK(opt) << "Call->op must contains a function struct info"; + TVM_FFI_ICHECK(opt) << "Call->op must contains a function struct info"; FuncStructInfo finfo = opt.value(); return DeriveCallRetStructInfo(finfo, call, ffi::GetRef(this), &analyzer_); } @@ -923,7 +925,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor()) { value = match_cast->value; } else { - LOG(FATAL) << "Unknown binding type: " << binding->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unknown binding type: " << binding->GetTypeKey(); } // if we encounter a nested seq, we have to flatten it: // 1. Append the binding block we've accumulated so far @@ -960,7 +962,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorvar << " is defined within a DataflowBlock, " @@ -977,7 +979,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor()) { current.push_back(MatchCast(match_cast->var, seq->body, match_cast->struct_info)); } else { - LOG(FATAL) << "Unknown binding type: " << binding->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unknown binding type: " << binding->GetTypeKey(); } } else { current.push_back(binding); @@ -1012,7 +1014,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorbindings.insert(n->bindings.end(), block->bindings.begin(), block->bindings.end()); merged = BindingBlock(n); } else { - LOG(FATAL) << "Unknown block type: " << ret.back()->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unknown block type: " << ret.back()->GetTypeKey(); } ret.pop_back(); ret.push_back(merged); diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index 249ec14f89dd..2f2d1dac9ae9 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -195,7 +195,7 @@ static std::optional TryValidate( std::function(const DFPatternNode*)> query_match_state = [&pattern2node, ¤t_match](const DFPatternNode* pattern) -> ffi::Optional { auto it = pattern2node.find(pattern); - ICHECK(it != pattern2node.end()) + TVM_FFI_ICHECK(it != pattern2node.end()) << "DFConstraint attempted to access DFPattern " << ffi::GetRef(pattern) << ", which does not appear in the PatternContext"; const auto& p_node = it->second; @@ -293,7 +293,8 @@ ffi::Optional> MatchGraph(const PatternContext& ctx, const ffi::Array& binding_arr, const ffi::Map& bindings) { // TODO(@ganler): Handle non-may external use. - ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; + TVM_FFI_ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) + << "Only kMay is supported yet."; DFPatternMatcher matcher(bindings); MatcherUseDefAnalysis ud_analysis; @@ -353,7 +354,7 @@ ffi::Optional> MatchGraph(const PatternContext& ctx, ffi::Map ret; for (const auto& [pat, p_node] : pattern2node) { - ICHECK(match->matched(p_node)); + TVM_FFI_ICHECK(match->matched(p_node)); ret.Set(ffi::GetRef(pat), ffi::GetRef(match->matched(p_node))); } return ret; diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index 4aca923a4b80..72f62041dbf0 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -134,8 +134,7 @@ ffi::Array TopologicalSort(const ffi::Array& bindings) { // All bindings should be emitted by this point. If any remain, // then there exists a circular dependency somewhere in the // remaining bindings. - CHECK(delayed_bindings.empty()) << "ValueError: " - << "Bindings contain circular dependency"; + TVM_FFI_CHECK(delayed_bindings.empty(), ValueError) << "Bindings contain circular dependency"; if (required_sorting) { return sorted_bindings; @@ -165,7 +164,7 @@ void RewriteSpec::Append(RewriteSpec other) { // The two rewrites provide the same GlobalVar. // (e.g. Multiple rewrites of the same pattern.) Ensure that // they are referring to the same underlying BaseFunc. - CHECK(func.same_as((*it).second)); + TVM_FFI_ICHECK(func.same_as((*it).second)); } else if (auto new_name = gvar_name_supply->FreshName(gvar->name_hint); new_name != gvar->name_hint) { // The two rewrites provide distinct GlobalVar subroutines, @@ -210,7 +209,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (auto mod = obj.as()) { return rewriter(mod.value()); } else { - LOG(FATAL) << "Unreachable: object does not contain either variant type"; + TVM_FFI_THROW(InternalError) + << "Unreachable: object does not contain either variant type"; } }); } @@ -326,9 +326,10 @@ OrRewriter::OrRewriter(PatternMatchingRewriter lhs, PatternMatchingRewriter rhs) } RewriteSpec TupleRewriterNode::RewriteBindings(const ffi::Array& bindings) const { - CHECK_LE(patterns.size(), 3) << "For performance reasons, " - << "matching of implicit tuple patterns is currently limited" - << " to tuples with 3 elements or fewer."; + TVM_FFI_ICHECK_LE(patterns.size(), 3) + << "For performance reasons, " + << "matching of implicit tuple patterns is currently limited" + << " to tuples with 3 elements or fewer."; ffi::Map variable_rewrites = GenerateVariableRewrites(bindings); if (variable_rewrites.size()) { @@ -404,7 +405,7 @@ ffi::Map TupleRewriterNode::GenerateVariableRewrites( }; auto decrement_indices = [&](std::vector& indices) -> bool { - ICHECK_EQ(indices.size(), patterns.size()); + TVM_FFI_ICHECK_EQ(indices.size(), patterns.size()); // Step 1, find the first index that can be decremented, while // still generating a valid set of indices. @@ -500,13 +501,13 @@ ffi::Map TupleRewriterNode::GenerateVariableRewrites( if (new_match) { const auto& [indices, exprs] = new_match.value(); - ICHECK_EQ(indices.size(), exprs.size()); + TVM_FFI_ICHECK_EQ(indices.size(), exprs.size()); for (size_t i = 0; i < indices.size(); i++) { - ICHECK_LT(indices[i], info_vec.size()); + TVM_FFI_ICHECK_LT(indices[i], info_vec.size()); auto& info = info_vec[indices[i]]; - ICHECK(!info.used) << "InternalError: " - << "Produced multiple replacements for variable " << info.var; + TVM_FFI_CHECK(!info.used, InternalError) + << "Produced multiple replacements for variable " << info.var; rewrites.Set(info.var, exprs[i]); binding_lookup.erase(info.var); @@ -528,9 +529,9 @@ ffi::Map TupleRewriterNode::GenerateVariableRewrites( std::optional> TupleRewriterNode::TryMatchByBindingIndex( const std::vector& info_vec, const std::vector& indices) const { - ICHECK_GE(indices.size(), 1); + TVM_FFI_ICHECK_GE(indices.size(), 1); - ICHECK_EQ(indices.size(), patterns.size()); + TVM_FFI_ICHECK_EQ(indices.size(), patterns.size()); for (size_t i = 0; i < indices.size(); i++) { const auto& info = info_vec[indices[i]]; if (info.used || !info.matches[i]) { @@ -596,7 +597,7 @@ std::optional> TupleRewriterNode::TryMatchByBindingIndex( std::vector rewrites; if (auto inline_tuple = rewritten.as()) { const auto& fields = inline_tuple->fields; - CHECK_EQ(fields.size(), indices.size()) + TVM_FFI_ICHECK_EQ(fields.size(), indices.size()) << "Expected to receive " << indices.size() << " values to replace TuplePattern with " << indices.size() << " fields, but received " << fields.size() << " values"; rewrites = {fields.begin(), fields.end()}; @@ -658,28 +659,24 @@ PatternMatchingRewriter PatternMatchingRewriter::FromPattern( PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { Function func_pattern = [&]() { - CHECK(mod->ContainGlobalVar("pattern")) - << "KeyError: " + TVM_FFI_CHECK(mod->ContainGlobalVar("pattern"), KeyError) << "Expected module to contain 'pattern', " << "a Relax function defining the pattern to be matched, " << "but the module did not contain a 'pattern' function."; auto base_func = mod->Lookup("pattern"); - CHECK(base_func->IsInstance()) - << "TypeError: " + TVM_FFI_CHECK(base_func->IsInstance(), TypeError) << "Expected module to contain 'pattern', " << "a Relax function defining the pattern to be matched, " << "but the 'pattern' function was of type " << base_func->GetTypeKey() << "."; return Downcast(base_func); }(); Function func_replacement = [&]() { - CHECK(mod->ContainGlobalVar("replacement")) - << "KeyError: " + TVM_FFI_CHECK(mod->ContainGlobalVar("replacement"), KeyError) << "Expected module to contain 'replacement', " << "a Relax function defining the replacement to be matched, " << "but the module did not contain a 'replacement' function."; auto base_func = mod->Lookup("replacement"); - CHECK(base_func->IsInstance()) - << "TypeError: " + TVM_FFI_CHECK(base_func->IsInstance(), TypeError) << "Expected module to contain 'replacement', " << "a Relax function defining the replacement to be made on a successful match, " << "but the 'replacement' function was of type " << base_func->GetTypeKey() << "."; @@ -690,19 +687,18 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { for (const auto& [gvar, func] : mod->functions) { if (gvar->name_hint != "pattern" && gvar->name_hint != "replacement") { bool is_public = func->GetAttr(tvm::attr::kGlobalSymbol).has_value(); - CHECK(!is_public) << "ValueError: " - << "Expected module to have no publicly-exposed functions " - << "other than 'pattern' and 'replacement'. " - << "However, function '" << gvar->name_hint << "' of type " - << func->GetTypeKey() << " is publicly exposed."; + TVM_FFI_CHECK(!is_public, ValueError) + << "Expected module to have no publicly-exposed functions " + << "other than 'pattern' and 'replacement'. " + << "However, function '" << gvar->name_hint << "' of type " << func->GetTypeKey() + << " is publicly exposed."; new_subroutines.Set(gvar, func); } } auto sinfo_pattern = GetStructInfo(func_pattern); auto sinfo_replacement = GetStructInfo(func_replacement); - CHECK(StructuralEqual()(sinfo_pattern, sinfo_replacement)) - << "ValueError: " + TVM_FFI_CHECK(StructuralEqual()(sinfo_pattern, sinfo_replacement), ValueError) << "The pattern and replacement must have the same signature, " << "but the pattern has struct info " << sinfo_pattern << ", while the replacement has struct info " << sinfo_replacement; @@ -742,9 +738,8 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { return StructInfoPattern(WildcardPattern(), PrimStructInfo(prim->value)); } else { - LOG(FATAL) << "TypeError: " - << "Cannot convert Relax expression of type " << expr->GetTypeKey() - << " into pattern-matching rule."; + TVM_FFI_THROW(TypeError) << "Cannot convert Relax expression of type " << expr->GetTypeKey() + << " into pattern-matching rule."; } }; @@ -769,7 +764,7 @@ PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { ffi::Array new_blocks; ffi::Array wildcard_bindings; - ICHECK_EQ(param_wildcards.size(), func_replacement->params.size()); + TVM_FFI_ICHECK_EQ(param_wildcards.size(), func_replacement->params.size()); for (size_t i = 0; i < param_wildcards.size(); i++) { Expr matched_expr = matches[param_wildcards[i]]; @@ -1002,7 +997,7 @@ class PatternMatchingMutator : public ExprMutator { } else if (auto match_cast = binding.as()) { builder_->EmitNormalized(MatchCast(binding->var, value, match_cast->struct_info)); } else { - LOG(FATAL) << "Binding must be either VarBinding or MatchCast"; + TVM_FFI_THROW(InternalError) << "Binding must be either VarBinding or MatchCast"; } } return builder_->EndBlock(); @@ -1021,7 +1016,7 @@ class PatternMatchingMutator : public ExprMutator { auto last_binding = last_block->bindings.back(); last_block.CopyOnWrite()->bindings.pop_back(); - ICHECK(last_binding->var.same_as(dummy_output_var)); + TVM_FFI_ICHECK(last_binding->var.same_as(dummy_output_var)); if (last_block->bindings.size()) { new_blocks.push_back(last_block); @@ -1045,10 +1040,11 @@ Expr PatternMatchingRewriter::operator()(Expr expr) { PatternMatchingMutator mutator(get()); auto new_expr = mutator(expr); auto new_subroutines = mutator.GetNewSubroutines(); - CHECK_EQ(new_subroutines.size(), 0) << "If PatternMatchingRewriter provides subroutines, " - << "then it must be applied to an entire IRModule. " - << "However, PatternMatchingRewriter produced subroutines " - << [&]() -> ffi::Array { + TVM_FFI_ICHECK_EQ(new_subroutines.size(), 0) + << "If PatternMatchingRewriter provides subroutines, " + << "then it must be applied to an entire IRModule. " + << "However, PatternMatchingRewriter produced subroutines " + << [&]() -> ffi::Array { std::vector vec; for (const auto& [gvar, func] : new_subroutines) { vec.push_back(gvar); diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 5c0fd6d8f554..2f7099937fec 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -145,7 +145,7 @@ void DFPatternMatcher::ClearMap(size_t watermark) { } bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr0) { - CHECK(pattern.defined()) << "Null pattern found when matching against " << expr0; + TVM_FFI_ICHECK(pattern.defined()) << "Null pattern found when matching against " << expr0; auto expr = UnwrapBindings(expr0, var2val_); if (memoize_ && memo_.count(pattern)) { @@ -414,7 +414,8 @@ bool DFPatternMatcher::VisitDFPattern_(const UnorderedTuplePatternNode* op, cons if (const auto* tuple_node = expr.as()) { if (op->fields.size() == tuple_node->fields.size()) { constexpr int8_t kUnknown = -1; - ICHECK_LE(op->fields.size(), std::numeric_limits::max()) << "Too many fields!"; + TVM_FFI_ICHECK_LE(op->fields.size(), std::numeric_limits::max()) + << "Too many fields!"; // dynamic programming. std::vector match_cache(op->fields.size() * op->fields.size(), kUnknown); std::vector field_match_bitmap(op->fields.size(), false); diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index 99e7dc6dfe05..ce34f4705fba 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ -483,7 +483,7 @@ ffi::Optional PatternContext::Current() { PatternContext::PatternContext(bool incremental) { auto n = ffi::make_object(); if (incremental) { - ICHECK(!pattern_ctx_stack().empty()) + TVM_FFI_ICHECK(!pattern_ctx_stack().empty()) << "Incremental context needs to be built inside a existing context."; n->allow_extern_use = pattern_ctx_stack().top()->allow_extern_use; n->edge_constraints = pattern_ctx_stack().top()->edge_constraints; @@ -496,7 +496,7 @@ PatternContext::PatternContext(bool incremental) { void PatternContext::EnterWithScope() const { pattern_ctx_stack().push(*this); } void PatternContext::ExitWithScope() const { - ICHECK(pattern_ctx_stack().top().same_as(*this)); + TVM_FFI_ICHECK(pattern_ctx_stack().top().same_as(*this)); pattern_ctx_stack().pop(); } @@ -513,7 +513,7 @@ PatternSeq::PatternSeq(DFPattern init_pattern) { data_ = std::move(n); } PatternSeq::PatternSeq(tvm::ffi::Array patterns, bool only_used_by) { - ICHECK_GE(patterns.size(), 1) << "PatternSeq must have at least one pattern"; + TVM_FFI_ICHECK_GE(patterns.size(), 1) << "PatternSeq must have at least one pattern"; const auto cons = PairCons(only_used_by ? PairCons::kOnlyUsedBy : PairCons::kUsedBy); ObjectPtr n = ffi::make_object(); diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index ee10a97aa0e7..7fcc5e8b9b62 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -59,12 +59,13 @@ te::Tensor TETensor(Expr value, ffi::Map tir_var_map, std::s n->shape = std::move(shape); return te::PlaceholderOp(n).output(0); } - ICHECK(value->struct_info_.defined()) << "value must be normalized and contain StructInfo"; + TVM_FFI_ICHECK(value->struct_info_.defined()) + << "value must be normalized and contain StructInfo"; auto* tensor_sinfo = GetStructInfoAs(value); - ICHECK(tensor_sinfo) << "Value must be a tensor"; + TVM_FFI_ICHECK(tensor_sinfo) << "Value must be a tensor"; auto* shape_expr = tensor_sinfo->shape.as(); - CHECK(shape_expr) - << "ValueError: Expression does not have an known symbolic shape, please consider use " + TVM_FFI_CHECK(shape_expr, ValueError) + << "Expression does not have an known symbolic shape, please consider use " "match_cast " << "to constrain the shape before passing into te_tensor"; n->shape = shape_expr->values.Map( diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 1d5b715cfb36..b011327e8db1 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -60,8 +60,8 @@ Id::Id(ffi::String name_hint) { Call::Call(Expr op, ffi::Array args, Attrs attrs, ffi::Array sinfo_args, Span span) { - CHECK(!op->struct_info_.defined() || op->struct_info_->IsInstance()) - << "ValueError: " + TVM_FFI_CHECK(!op->struct_info_.defined() || op->struct_info_->IsInstance(), + ValueError) << "Call expects its operator to have FuncStructInfo, " << "but operator " << op << ", which was called with arguments " << args << ", has struct info " << op->struct_info_; @@ -213,12 +213,12 @@ Tuple WithFields(Tuple tuple, ffi::Optional> opt_fields, } TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { - CHECK_GE(index, 0) << "Index out of bounds: Tuple " << tuple - << " cannot be accessed with negative index " << index; + TVM_FFI_ICHECK_GE(index, 0) << "Index out of bounds: Tuple " << tuple + << " cannot be accessed with negative index " << index; ObjectPtr n = ffi::make_object(); if (auto* tuple_info = tuple->struct_info_.as()) { - CHECK_LT(index, tuple_info->fields.size()) + TVM_FFI_ICHECK_LT(index, tuple_info->fields.size()) << "Index out of bounds: Tuple " << tuple << " is of size " << tuple_info->fields.size() << ", and cannot be accessed with index " << index; auto sinfo = tuple_info->fields[index]; @@ -261,7 +261,7 @@ ShapeExpr::ShapeExpr(ffi::Array values, Span span) { if (value->IsInstance()) { return tvm::cast(DataType::Int(64), value); } - ICHECK(value.dtype() == DataType::Int(64)) + TVM_FFI_ICHECK(value.dtype() == DataType::Int(64)) << "the value in ShapeStructInfo can only have dtype of int64"; return value; }); @@ -291,7 +291,7 @@ VarNode* Var::CopyOnWrite() { // If the `TVM_DEFINE_OBJECT_REF_COW_METHOD` were used, the // automatic implementation would erroneously convert from a // `DataflowBlock` to a `Var`. - ICHECK(data_ != nullptr); + TVM_FFI_ICHECK(data_ != nullptr); if (!data_.unique()) { ObjectPtr node; if (auto dataflow_var = as()) { @@ -413,7 +413,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { MatchCast::MatchCast(Var var, Expr value, StructInfo struct_info, Span span) { ObjectPtr n = ffi::make_object(); - ICHECK(var.defined()) << "MatchCast requires var to be defined"; + TVM_FFI_ICHECK(var.defined()) << "MatchCast requires var to be defined"; n->var = std::move(var); n->value = std::move(value); n->struct_info = std::move(struct_info); @@ -485,7 +485,7 @@ BindingBlockNode* BindingBlock::CopyOnWrite() { // If the `TVM_DEFINE_OBJECT_REF_COW_METHOD` were used, the // automatic implementation would erroneously convert from a // `DataflowBlock` to a `BindingBlock`. - ICHECK(data_ != nullptr); + TVM_FFI_ICHECK(data_ != nullptr); if (!data_.unique()) { ObjectPtr node; if (auto dataflow_block = as()) { @@ -554,7 +554,7 @@ Function::Function(ffi::Array params, Expr body, ffi::Optional ffi::Array param_sinfo; for (const Var& param : params) { - CHECK(param->struct_info_.defined()) + TVM_FFI_ICHECK(param->struct_info_.defined()) << "relax.Function requires params to contain struct_info_"; param_sinfo.push_back(GetStructInfo(param)); } @@ -565,7 +565,7 @@ Function::Function(ffi::Array params, Expr body, ffi::Optional body_sinfo = GetStructInfo(body); } - CHECK(body_sinfo.defined() || ret_struct_info.defined()) + TVM_FFI_ICHECK(body_sinfo.defined() || ret_struct_info.defined()) << "Function must be constructed with either " << "an explicit struct info for the return type, " << "or a normalized body with struct info."; @@ -623,7 +623,7 @@ Function Function::CreateEmpty(ffi::Array params, StructInfo ret_struct_inf DictAttrs attrs, Span span) { ffi::Array param_sinfo; for (const Var& param : params) { - ICHECK(param->struct_info_.defined()) + TVM_FFI_ICHECK(param->struct_info_.defined()) << "relax.Function requires params to contain struct_info_."; param_sinfo.push_back(GetStructInfo(param)); } @@ -665,7 +665,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tvm.relax.struct_info.infer_by_sinfo_args", [](const Call& call, const BlockBuilder& ctx) -> StructInfo { - ICHECK(call->sinfo_args.defined()) + TVM_FFI_ICHECK(call->sinfo_args.defined()) << "sinfo_args field of CallNode should always be defined"; if (call->sinfo_args.empty()) { return ObjectStructInfo(); @@ -689,7 +689,7 @@ ExternFunc::ExternFunc(ffi::String global_symbol, Span span) : ExternFunc(global_symbol, GetExternFuncStructInfo(), span) {} ExternFunc::ExternFunc(ffi::String global_symbol, StructInfo struct_info, Span span) { - CHECK(struct_info.as()) + TVM_FFI_ICHECK(struct_info.as()) << "ExternFunc must have FuncStructInfo, " << "but declaration of '" << global_symbol << "' received " << struct_info; @@ -714,10 +714,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { Expr GetShapeOf(const Expr& expr) { // default case, to be normalized. - ICHECK(expr->struct_info_.defined()) << "GetShapeOf can only be applied to normalized expr"; + TVM_FFI_ICHECK(expr->struct_info_.defined()) + << "GetShapeOf can only be applied to normalized expr"; auto* tinfo = GetStructInfoAs(expr); - ICHECK(tinfo != nullptr) << "ShapeOf can only be applied to expr with TensorStructInfo"; + TVM_FFI_ICHECK(tinfo != nullptr) << "ShapeOf can only be applied to expr with TensorStructInfo"; if (tinfo->shape.defined()) return tinfo->shape.value(); static const Op& op = Op::Get("relax.shape_of"); diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 6ebc56feebe2..13ef41eede43 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -37,34 +37,34 @@ self->VisitBinding_(binding, static_cast(n.get())); \ }); -#define RELAX_VAR_BINDING_DISPATCH_IMPL(Type) \ - Type::VisitBindingVTable Type::InitVisitBindingVTable() { \ - VisitBindingVTable vtable; \ - RELAX_VISIT_BINDING_DISPATCH(ConstantNode); \ - RELAX_VISIT_BINDING_DISPATCH(TupleNode); \ - RELAX_VISIT_BINDING_DISPATCH(VarNode); \ - RELAX_VISIT_BINDING_DISPATCH(DataflowVarNode); \ - RELAX_VISIT_BINDING_DISPATCH(ShapeExprNode); \ - RELAX_VISIT_BINDING_DISPATCH(ExternFuncNode); \ - RELAX_VISIT_BINDING_DISPATCH(GlobalVarNode); \ - RELAX_VISIT_BINDING_DISPATCH(FunctionNode); \ - RELAX_VISIT_BINDING_DISPATCH(CallNode); \ - RELAX_VISIT_BINDING_DISPATCH(SeqExprNode); \ - RELAX_VISIT_BINDING_DISPATCH(IfNode); \ - RELAX_VISIT_BINDING_DISPATCH(OpNode); \ - RELAX_VISIT_BINDING_DISPATCH(TupleGetItemNode); \ - RELAX_VISIT_BINDING_DISPATCH(PrimValueNode); \ - RELAX_VISIT_BINDING_DISPATCH(StringImmNode); \ - RELAX_VISIT_BINDING_DISPATCH(DataTypeImmNode); \ - return vtable; \ - } \ - void Type::VisitBinding_(const VarBindingNode* binding) { \ - static VisitBindingVTable vtable = InitVisitBindingVTable(); \ - const Expr& value = binding->value; \ - ICHECK(value.defined()) << "Found null pointer node while traversing AST."; \ - ICHECK(vtable.can_dispatch(value)) \ - << "VisitVarBinding do not allow binding value type" << value->GetTypeKey(); \ - vtable(value, this, binding); \ +#define RELAX_VAR_BINDING_DISPATCH_IMPL(Type) \ + Type::VisitBindingVTable Type::InitVisitBindingVTable() { \ + VisitBindingVTable vtable; \ + RELAX_VISIT_BINDING_DISPATCH(ConstantNode); \ + RELAX_VISIT_BINDING_DISPATCH(TupleNode); \ + RELAX_VISIT_BINDING_DISPATCH(VarNode); \ + RELAX_VISIT_BINDING_DISPATCH(DataflowVarNode); \ + RELAX_VISIT_BINDING_DISPATCH(ShapeExprNode); \ + RELAX_VISIT_BINDING_DISPATCH(ExternFuncNode); \ + RELAX_VISIT_BINDING_DISPATCH(GlobalVarNode); \ + RELAX_VISIT_BINDING_DISPATCH(FunctionNode); \ + RELAX_VISIT_BINDING_DISPATCH(CallNode); \ + RELAX_VISIT_BINDING_DISPATCH(SeqExprNode); \ + RELAX_VISIT_BINDING_DISPATCH(IfNode); \ + RELAX_VISIT_BINDING_DISPATCH(OpNode); \ + RELAX_VISIT_BINDING_DISPATCH(TupleGetItemNode); \ + RELAX_VISIT_BINDING_DISPATCH(PrimValueNode); \ + RELAX_VISIT_BINDING_DISPATCH(StringImmNode); \ + RELAX_VISIT_BINDING_DISPATCH(DataTypeImmNode); \ + return vtable; \ + } \ + void Type::VisitBinding_(const VarBindingNode* binding) { \ + static VisitBindingVTable vtable = InitVisitBindingVTable(); \ + const Expr& value = binding->value; \ + TVM_FFI_ICHECK(value.defined()) << "Found null pointer node while traversing AST."; \ + TVM_FFI_ICHECK(vtable.can_dispatch(value)) \ + << "VisitVarBinding do not allow binding value type" << value->GetTypeKey(); \ + vtable(value, this, binding); \ } // functions to be overriden. @@ -286,7 +286,7 @@ void ExprVisitor::VisitBinding(const Binding& binding) { } else if (const auto* node = binding.as()) { VisitBinding_(node); } else { - LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << binding->GetTypeKey(); } } @@ -296,7 +296,7 @@ void ExprVisitor::VisitBindingBlock(const BindingBlock& block) { } else if (const auto* node = block.as()) { VisitBindingBlock_(node); } else { - LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << block->GetTypeKey(); } } @@ -306,7 +306,7 @@ void ExprVisitor::VisitVarDef(const Var& var) { } else if (const auto* node = var.as()) { VisitVarDef_(node); } else { - LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << var->GetTypeKey(); } } @@ -531,11 +531,11 @@ BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) { Expr new_value = this->VisitExpr(match_cast->value); bindings.push_back(MatchCast(match_cast->var, new_value, match_cast->struct_info)); } else { - LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << binding->GetTypeKey(); } } } else { - LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << block->GetTypeKey(); } if (block.as()) { @@ -677,8 +677,7 @@ void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) { auto new_sinfo = new_value->struct_info_.as(); - ICHECK(new_sinfo) - << "InternalError: " + TVM_FFI_CHECK(new_sinfo, InternalError) << "In binding of variable " << binding->var << ", the value " << new_value << " does not have StructInfo. " << "This typically occurs when ReEmitBinding is called without first calling Normalize."; @@ -766,7 +765,7 @@ void ExprMutator::VisitBinding(const Binding& binding) { } else if (const auto* node = binding.as()) { VisitBinding_(node); } else { - LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << binding->GetTypeKey(); } } @@ -777,7 +776,7 @@ BindingBlock ExprMutator::VisitBindingBlock(const BindingBlock& block) { } else if (const auto* node = block.as()) { ret = VisitBindingBlock_(node); } else { - LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << block->GetTypeKey(); } return ret; } @@ -789,13 +788,13 @@ Var ExprMutator::VisitVarDef(const Var& var) { } else if (const auto* node = var.as()) { ret = VisitVarDef_(node); } else { - LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << var->GetTypeKey(); } return ret; } Expr ExprMutator::VisitWithNewScope(const Expr& expr, ffi::Optional> params) { - ICHECK(expr->IsInstance()) + TVM_FFI_ICHECK(expr->IsInstance()) << "Normal form requires all new scope is stored as SeqExpr"; PrimExpr constraint = Bool(true); @@ -829,7 +828,7 @@ Expr ExprMutator::VisitWithNewScope(const Expr& expr, ffi::OptionalIsInstance()) + TVM_FFI_ICHECK(expr->IsInstance()) << "Normal form requires all new scope is stored as SeqExpr"; builder_->BeginInnerScope(); @@ -843,7 +842,7 @@ ffi::Optional ExprMutator::LookupBinding(const Var& var) { } Var ExprMutator::WithStructInfo(Var var, StructInfo struct_info) { - ICHECK(struct_info.defined()); + TVM_FFI_ICHECK(struct_info.defined()); // TODO(relax-team) add StructInfoEqual check if (var->struct_info_.defined()) { diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc index b7d61bfda8ec..7eca9af0f4d7 100644 --- a/src/relax/ir/py_expr_functor.cc +++ b/src/relax/ir/py_expr_functor.cc @@ -576,7 +576,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (const auto* ptr = binding.as()) { visitor->ExprVisitor::VisitBinding_(ptr); } else { - LOG(FATAL) << "unreachable"; + TVM_FFI_THROW(InternalError) << "unreachable"; } }) .def("relax.ExprVisitorVisitBindingBlock", @@ -586,7 +586,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (const auto* ptr = block.as()) { visitor->ExprVisitor::VisitBindingBlock_(ptr); } else { - LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << block->GetTypeKey(); } }) .def("relax.ExprVisitorVisitVarDef", @@ -596,7 +596,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (const auto* node = var.as()) { visitor->ExprVisitor::VisitVarDef_(node); } else { - LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << var->GetTypeKey(); } }) .def("relax.ExprVisitorVisitSpan", @@ -623,7 +623,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (const auto* ptr = binding.as()) { return mutator->ExprMutator::VisitBinding_(ptr); } else { - LOG(FATAL) << "unreachable"; + TVM_FFI_THROW(InternalError) << "unreachable"; } }) .def("relax.ExprMutatorVisitBindingBlock", @@ -633,7 +633,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (const auto* node = block.as()) { return mutator->ExprMutator::VisitBindingBlock_(node); } else { - LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << block->GetTypeKey(); } }) .def("relax.ExprMutatorVisitVarDef", @@ -643,7 +643,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (const auto* node = var.as()) { return mutator->ExprMutator::VisitVarDef_(node); } else { - LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << var->GetTypeKey(); } }) .def( diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc index 22ed4e9ea382..434917bd0f94 100644 --- a/src/relax/ir/struct_info.cc +++ b/src/relax/ir/struct_info.cc @@ -85,7 +85,7 @@ ShapeStructInfo::ShapeStructInfo(ffi::Array values, Span span) { if (value->IsInstance()) { return tvm::cast(DataType::Int(64), value); } - ICHECK(value.dtype() == DataType::Int(64)) + TVM_FFI_ICHECK(value.dtype() == DataType::Int(64)) << "the value in ShapeStructInfo can only have dtype of int64"; return value; }); @@ -95,7 +95,7 @@ ShapeStructInfo::ShapeStructInfo(ffi::Array values, Span span) { ShapeStructInfo::ShapeStructInfo(int ndim, Span span) { ObjectPtr n = ffi::make_object(); - CHECK_GE(ndim, -1) << "ndim of ShapeStructInfo must be >= -1, but got " << ndim; + TVM_FFI_ICHECK_GE(ndim, -1) << "ndim of ShapeStructInfo must be >= -1, but got " << ndim; n->ndim = ndim; n->span = span; data_ = std::move(n); @@ -106,7 +106,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def( "relax.ShapeStructInfo", [](ffi::Optional> values, int ndim, Span span) { if (values.defined()) { - CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify values and ndim"; + TVM_FFI_CHECK_EQ(ndim, kUnknownNDim, ValueError) << "Cannot both specify values and ndim"; return ShapeStructInfo(values.value(), span); } else { return ShapeStructInfo(ndim, span); @@ -120,9 +120,9 @@ TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, ffi::Optional n = ffi::make_object(); // assign ndim before move ffi::Optional sinfo = MatchStructInfo(shape); - ICHECK(sinfo) << "We expect shape to contain pre-set shape struct info"; - ICHECK(shape.defined()) << "Must provide a shape in this constructor"; - ICHECK(shape->IsInstance() || shape->IsInstance()) + TVM_FFI_ICHECK(sinfo) << "We expect shape to contain pre-set shape struct info"; + TVM_FFI_ICHECK(shape.defined()) << "Must provide a shape in this constructor"; + TVM_FFI_ICHECK(shape->IsInstance() || shape->IsInstance()) << "We require shape to be normalized when constructing TensorStructInfo"; n->ndim = sinfo.value()->ndim; // assign rest of the fields. @@ -136,7 +136,7 @@ TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, ffi::Optional vdevice, Span span) { ObjectPtr n = ffi::make_object(); - CHECK_GE(ndim, -1) << "ndim of TensorStructInfo must be >= -1, but got " << ndim; + TVM_FFI_ICHECK_GE(ndim, -1) << "ndim of TensorStructInfo must be >= -1, but got " << ndim; n->ndim = ndim; n->dtype = dtype; n->vdevice = vdevice; @@ -150,7 +150,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { "relax.TensorStructInfo", [](ffi::Optional shape, ffi::Optional dtype, int ndim, VDevice vdevice, Span span) { if (shape.defined()) { - CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify shape and ndim"; + TVM_FFI_CHECK_EQ(ndim, kUnknownNDim, ValueError) << "Cannot both specify shape and ndim"; return TensorStructInfo(shape.value(), dtype.value_or(DataType::Void()), vdevice, span); } else { return TensorStructInfo(dtype.value_or(DataType::Void()), ndim, vdevice, span); @@ -209,21 +209,21 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](ffi::Array params, StructInfo ret, bool purity, Span span) { return FuncStructInfo(params, ret, purity, span); }) - .def("relax.FuncStructInfoOpaqueFunc", - [](ffi::Optional ret, ffi::Optional derive_func, - bool purity, Span span) { - if (derive_func.defined()) { - ICHECK(!ret.defined()) << "ValueError: Cannot specify both ret and derive_func"; - return FuncStructInfo::OpaqueFunc(derive_func.value(), purity, span); - } else { - return FuncStructInfo::OpaqueFunc(ret.value_or(ObjectStructInfo()), purity, span); - } - }); + .def("relax.FuncStructInfoOpaqueFunc", [](ffi::Optional ret, + ffi::Optional derive_func, + bool purity, Span span) { + if (derive_func.defined()) { + TVM_FFI_CHECK(!ret.defined(), ValueError) << "Cannot specify both ret and derive_func"; + return FuncStructInfo::OpaqueFunc(derive_func.value(), purity, span); + } else { + return FuncStructInfo::OpaqueFunc(ret.value_or(ObjectStructInfo()), purity, span); + } + }); } // Helper functions void UpdateStructInfo(Expr expr, StructInfo struct_info) { - ICHECK(!expr->struct_info_.defined()) + TVM_FFI_ICHECK(!expr->struct_info_.defined()) << "To ensure idempotency, " << "the expression passed to UpdateStructInfo " << "must not have any prior StructInfo. " diff --git a/src/relax/ir/struct_info_functor.cc b/src/relax/ir/struct_info_functor.cc index 58df3c24ff8e..bab929c39184 100644 --- a/src/relax/ir/struct_info_functor.cc +++ b/src/relax/ir/struct_info_functor.cc @@ -146,7 +146,7 @@ StructInfo StructInfoMutator::VisitStructInfo_(const FuncStructInfoNode* op) { if (params.same_as(op->params) && ret.same_as(op->ret)) { return ffi::GetRef(op); } else { - ICHECK(ret.defined()) << "FuncStructInfo that contains params must contain ret"; + TVM_FFI_ICHECK(ret.defined()) << "FuncStructInfo that contains params must contain ret"; return FuncStructInfo(params.value(), ret, op->purity, op->span); } } diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc index 281a60375b92..d787f906d2ca 100644 --- a/src/relax/ir/transform.cc +++ b/src/relax/ir/transform.cc @@ -118,12 +118,12 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) pass_ctx->diag_ctx = previous; } - ICHECK(pass_ctx->diag_ctx) + TVM_FFI_ICHECK(pass_ctx->diag_ctx) << "The diagnostic context was set at the top of this block this is a bug."; const PassInfo& pass_info = Info(); - ICHECK(mod.defined()); + TVM_FFI_ICHECK(mod.defined()); VLOG_CONTEXT << pass_info->name; VLOG(0) << "Executing function pass with opt level: " << pass_info->opt_level; @@ -145,7 +145,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) updated_mod->Add(pair.first, pair.second, true); } - ICHECK(pass_ctx->diag_ctx) + TVM_FFI_ICHECK(pass_ctx->diag_ctx) << "The diagnostic context was set at the top of this block, this is a bug."; pass_ctx->diag_ctx.value().Render(); @@ -261,19 +261,19 @@ class DataflowBlockMutator : public ExprMutator { for (const tir::VarNode* var : collected_vars) { if (symbolic_vars.count(var->name_hint) > 0) { tir::Var old_var = symbolic_vars[var->name_hint]; - ICHECK(var == old_var.get()) + TVM_FFI_ICHECK(var == old_var.get()) << "Error: DataflowBlock Pass should not rewrite any Symbolic Var."; symbolic_vars.erase(var->name_hint); } } } if (!var.as() && global_scope_vars.count(var->name_hint()) > 0) { - ICHECK(var.same_as(global_scope_vars[var->name_hint()])) + TVM_FFI_ICHECK(var.same_as(global_scope_vars[var->name_hint()])) << "Error: DataflowBlock Pass should not rewrite any GlobalScope Var."; global_scope_vars.erase(var->name_hint()); } } - ICHECK(global_scope_vars.empty() && symbolic_vars.empty()) + TVM_FFI_ICHECK(global_scope_vars.empty() && symbolic_vars.empty()) << "Error: DataflowBlock Pass should not delete any GlobalScope/Symbolic Var."; return updated_block; @@ -339,12 +339,12 @@ IRModule DataflowBlockPassNode::operator()(IRModule mod, const PassContext& pass pass_ctx->diag_ctx = previous; } - ICHECK(pass_ctx->diag_ctx) + TVM_FFI_ICHECK(pass_ctx->diag_ctx) << "The diagnostic context was set at the top of this block, this is a bug."; const PassInfo& pass_info = Info(); - ICHECK(mod.defined()); + TVM_FFI_ICHECK(mod.defined()); VLOG_CONTEXT << pass_info->name; VLOG(0) << "Executing DataflowBlock pass with opt level: " << pass_info->opt_level; @@ -367,7 +367,7 @@ IRModule DataflowBlockPassNode::operator()(IRModule mod, const PassContext& pass updated_mod->Add(pair.first, pair.second, true); } - ICHECK(pass_ctx->diag_ctx) + TVM_FFI_ICHECK(pass_ctx->diag_ctx) << "The diagnostic context was set at the top of this block this is a bug."; pass_ctx->diag_ctx.value().Render(); diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index 29036f42f846..8818bd8e3179 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -148,7 +148,8 @@ StructInfo InferStructInfoScatter(const Call& call, const BlockBuilder& ctx) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); auto input_shape = input_sinfo->GetShape(); - CHECK(input_shape.defined()) << "input tensor of scatter_from_worker0 should have defined shape."; + TVM_FFI_ICHECK(input_shape.defined()) + << "input tensor of scatter_from_worker0 should have defined shape."; if (analyzer->CanProve(floormod(input_shape.value()[attrs->axis], PrimExpr(num_workers)) != 0)) { ctx->ReportFatal(Diagnostic::Error(call) diff --git a/src/relax/op/distributed/binary.h b/src/relax/op/distributed/binary.h index 127dec433afa..4dae46c578e3 100644 --- a/src/relax/op/distributed/binary.h +++ b/src/relax/op/distributed/binary.h @@ -46,7 +46,7 @@ StructInfo InferDistStructInfoBroadcast(const Call& call, const BlockBuilder& ct DataType output_dtype = f_compute_out_dtype(call, ctx, x1_sinfo, x2_sinfo); // ndims - ICHECK(!x1_sinfo->IsUnknownNdim() && !x2_sinfo->IsUnknownNdim()) + TVM_FFI_ICHECK(!x1_sinfo->IsUnknownNdim() && !x2_sinfo->IsUnknownNdim()) << "Unknown ndim is not supported for distributed operators."; int output_ndim = std::max(x1_sinfo->ndim, x2_sinfo->ndim); @@ -61,7 +61,7 @@ StructInfo InferDistStructInfoBroadcast(const Call& call, const BlockBuilder& ct if (!output_shape.defined()) { output_tensor_sinfo = TensorStructInfo(output_dtype, /*ndim=*/output_ndim); } else { - ICHECK_EQ(static_cast(output_shape.value().size()), output_ndim); + TVM_FFI_ICHECK_EQ(static_cast(output_shape.value().size()), output_ndim); output_tensor_sinfo = TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype); } } else { diff --git a/src/relax/op/distributed/ccl.cc b/src/relax/op/distributed/ccl.cc index 6ba63986980e..cb48ff38aa4f 100644 --- a/src/relax/op/distributed/ccl.cc +++ b/src/relax/op/distributed/ccl.cc @@ -26,7 +26,7 @@ namespace distributed { StructInfo InferDistStructInfoAllReduce(const Call& call, const BlockBuilder& ctx) { ffi::Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); - ICHECK(input_dtensor_sinfos.size() == 1); + TVM_FFI_ICHECK(input_dtensor_sinfos.size() == 1); DTensorStructInfo input_dtensor_sinfo = input_dtensor_sinfos[0]; TensorStructInfo tensor_sinfo = input_dtensor_sinfo->tensor_sinfo; DeviceMesh device_mesh = input_dtensor_sinfo->device_mesh; diff --git a/src/relax/op/distributed/distributed.cc b/src/relax/op/distributed/distributed.cc index 636891366194..242c4af4f82a 100644 --- a/src/relax/op/distributed/distributed.cc +++ b/src/relax/op/distributed/distributed.cc @@ -87,7 +87,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { StructInfo InferDistStructInfoRedistribute(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); const auto* sinfo = GetStructInfoAs(call->args[0]); - ICHECK(sinfo); + TVM_FFI_ICHECK(sinfo); return distributed::DTensorStructInfo(sinfo->tensor_sinfo, attrs->device_mesh, attrs->placement); } @@ -102,7 +102,7 @@ StructInfo InferStructInfoCallTIRLocalView(const Call& call, const BlockBuilder& ctx->ReportFatal(Diagnostic::Error(call) << "sinfo_args should have exactly 1 output struct info."); } - CHECK(call->args[0]->IsInstance()) + TVM_FFI_ICHECK(call->args[0]->IsInstance()) << "call_tir_local_view expects the first argument to be a GlobalVar referring to a TIR " "PrimFunc. " << "However, gets " << call->args[0]; @@ -124,7 +124,7 @@ Expr MakeCallTIRLocalView(Expr func, Tuple args, ffi::Optional packed_ints) { for (const distributed::DTensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->tensor_sinfo->shape.as(); - CHECK(shape != nullptr) + TVM_FFI_ICHECK(shape != nullptr) << "out_sinfo of call_tir_local_view should have defined ShapeExpr as shape. " "However, one given structure info is " << sinfo; @@ -162,7 +162,7 @@ StructInfo InferStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { arith::Analyzer* analyzer = ctx->GetAnalyzer(); auto input_shape = input_sinfo->GetShape(); - CHECK(input_shape.defined()) + TVM_FFI_ICHECK(input_shape.defined()) << "input tensor of redistribute_replica_to_shard should have defined shape."; if (analyzer->CanProve(floormod(input_shape.value()[attrs->axis], PrimExpr(num_workers))) != 0) { @@ -183,14 +183,14 @@ StructInfo InferStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { StructInfo InferDistStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { using namespace distributed; ffi::Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); - ICHECK(input_dtensor_sinfos.size() == 1); + TVM_FFI_ICHECK(input_dtensor_sinfos.size() == 1); DTensorStructInfo input_dtensor_sinfo = input_dtensor_sinfos[0]; TensorStructInfo tensor_sinfo = input_dtensor_sinfo->tensor_sinfo; const auto* attrs = call->attrs.as(); int num_workers = attrs->num_workers; arith::Analyzer* analyzer = ctx->GetAnalyzer(); auto input_shape = tensor_sinfo->GetShape(); - CHECK(input_shape.defined()) + TVM_FFI_ICHECK(input_shape.defined()) << "input tensor of redistribute_replica_to_shard should have defined shape."; if (analyzer->CanProve(floormod(input_shape.value()[attrs->axis], PrimExpr(num_workers))) != 0) { @@ -205,8 +205,8 @@ StructInfo InferDistStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { DeviceMesh device_mesh = input_dtensor_sinfo->device_mesh; // FIXME: this is a hack where there's only 1d mesh - ICHECK(device_mesh->shape.size() == 1); - ICHECK(input_dtensor_sinfo->placement->dim_specs[0]->kind == PlacementSpecKind::kReplica); + TVM_FFI_ICHECK(device_mesh->shape.size() == 1); + TVM_FFI_ICHECK(input_dtensor_sinfo->placement->dim_specs[0]->kind == PlacementSpecKind::kReplica); return DTensorStructInfo(tensor_sinfo, device_mesh, Placement::FromText("S[" + std::to_string(attrs->axis) + "]")); } diff --git a/src/relax/op/distributed/linear_algebra.cc b/src/relax/op/distributed/linear_algebra.cc index 8fc9cd58d1fc..aeee041afb40 100644 --- a/src/relax/op/distributed/linear_algebra.cc +++ b/src/relax/op/distributed/linear_algebra.cc @@ -74,7 +74,7 @@ StructInfo InferDistStructInfoMatmul(const Call& call, const BlockBuilder& ctx) x2_shape->values.end() - 2 + x2_appended}; ffi::Optional> output_shape_prefix = InferBinaryBroadcastShape(call, ctx, x1_shape_prefix, x2_shape_prefix); - ICHECK(output_shape_prefix.defined()) << "Failed to infer output shape of Matmul"; + TVM_FFI_ICHECK(output_shape_prefix.defined()) << "Failed to infer output shape of Matmul"; arith::Analyzer* analyzer = ctx->GetAnalyzer(); PrimExpr x1_reduction_length = x1_shape->values[x1_sinfo->ndim - 1]; PrimExpr x2_reduction_length = x2_shape->values[x2_ndim - 2]; @@ -92,7 +92,7 @@ StructInfo InferDistStructInfoMatmul(const Call& call, const BlockBuilder& ctx) if (!x2_appended) { output_shape.push_back(x2_shape->values[x2_ndim - 1]); } - ICHECK_EQ(static_cast(output_shape.size()), output_ndim); + TVM_FFI_ICHECK_EQ(static_cast(output_shape.size()), output_ndim); TensorStructInfo output_tensor_sinfo(ShapeExpr(output_shape), out_dtype); return InferShardingSpec(call, ctx, output_tensor_sinfo, distributed::BuildAxisGraphMatmul); } diff --git a/src/relax/op/distributed/manipulate.cc b/src/relax/op/distributed/manipulate.cc index edd5fa7ee7f9..4a629f145f21 100644 --- a/src/relax/op/distributed/manipulate.cc +++ b/src/relax/op/distributed/manipulate.cc @@ -105,7 +105,7 @@ StructInfo InferDistStructInfoReshape(const Call& call, const BlockBuilder& ctx) ffi::Optional> old_shape_values; if (data_sinfo->shape.defined()) { const auto* old_shape_sinfo = GetStructInfoAs(data_sinfo->shape.value()); - ICHECK_NOTNULL(old_shape_sinfo); + TVM_FFI_ICHECK_NOTNULL(old_shape_sinfo); old_shape_values = old_shape_sinfo->values; } diff --git a/src/relax/op/distributed/nn.cc b/src/relax/op/distributed/nn.cc index b020d7902f9b..f8200c03df15 100644 --- a/src/relax/op/distributed/nn.cc +++ b/src/relax/op/distributed/nn.cc @@ -26,7 +26,7 @@ namespace distributed { StructInfo InferDistStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) { ffi::Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); - ICHECK(input_dtensor_sinfos.size() == 1); + TVM_FFI_ICHECK(input_dtensor_sinfos.size() == 1); TensorStructInfo input_tensor_sinfo = input_dtensor_sinfos[0]->tensor_sinfo; if (input_tensor_sinfo->IsUnknownNdim()) { diff --git a/src/relax/op/distributed/op.cc b/src/relax/op/distributed/op.cc index ef780c6df8e0..11feb3e6b8f5 100644 --- a/src/relax/op/distributed/op.cc +++ b/src/relax/op/distributed/op.cc @@ -28,7 +28,7 @@ StructInfo InferDistStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) ctx->ReportFatal(Diagnostic::Error(call) << "sinfo_args should have exact 1 output struct info."); } - CHECK(call->args[0]->IsInstance()) + TVM_FFI_ICHECK(call->args[0]->IsInstance()) << "call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. " << "However, gets " << call->args[0]; return call->sinfo_args[0]; diff --git a/src/relax/op/distributed/statistical.cc b/src/relax/op/distributed/statistical.cc index 44ee90e78976..8c1cbc341275 100644 --- a/src/relax/op/distributed/statistical.cc +++ b/src/relax/op/distributed/statistical.cc @@ -45,7 +45,7 @@ StructInfo InferDistStructInfoStatistical(const Call& call, const BlockBuilder& ctx->ReportFatal(Diagnostic::Error(call) << "Input of distributed operator must be known ndim"); } else { out_ndim = data_sinfo->ndim - axes.size(); - ICHECK_GE(out_ndim, 0); + TVM_FFI_ICHECK_GE(out_ndim, 0); } // The inference rule for reduction operator output shapes: @@ -70,7 +70,7 @@ StructInfo InferDistStructInfoStatistical(const Call& call, const BlockBuilder& out_shape.push_back(IntImm(DataType::Int(64), /*value=*/1)); } } - ICHECK_EQ(static_cast(out_shape.size()), out_ndim); + TVM_FFI_ICHECK_EQ(static_cast(out_shape.size()), out_ndim); TensorStructInfo output_tensor_sinfo = TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype); return InferShardingSpec(call, ctx, output_tensor_sinfo, distributed::BuildAxisGraphReduce); diff --git a/src/relax/op/distributed/unary.h b/src/relax/op/distributed/unary.h index 727707a98525..4e00f4ca3461 100644 --- a/src/relax/op/distributed/unary.h +++ b/src/relax/op/distributed/unary.h @@ -36,7 +36,7 @@ StructInfo InferDistStructInfoUnary(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { ffi::Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); - ICHECK(input_dtensor_sinfos.size() == 1); + TVM_FFI_ICHECK(input_dtensor_sinfos.size() == 1); distributed::DTensorStructInfo input_dtensor_sinfo = input_dtensor_sinfos[0]; TensorStructInfo input_tensor_sinfo = input_dtensor_sinfo->tensor_sinfo; diff --git a/src/relax/op/distributed/utils.cc b/src/relax/op/distributed/utils.cc index ffa7dbfa3085..d8a23da3825b 100644 --- a/src/relax/op/distributed/utils.cc +++ b/src/relax/op/distributed/utils.cc @@ -45,8 +45,8 @@ StructInfo InferShardingSpec(const Call& call, const BlockBuilder& ctx, ffi::Array input_dtensor_sinfos = GetInputDTensorStructInfo(call, ctx); for (int i = 1; i < static_cast(input_dtensor_sinfos.size()); i++) { - ICHECK(StructuralEqual()(input_dtensor_sinfos[0]->device_mesh, - input_dtensor_sinfos[i]->device_mesh)); + TVM_FFI_ICHECK(StructuralEqual()(input_dtensor_sinfos[0]->device_mesh, + input_dtensor_sinfos[i]->device_mesh)); } distributed::DeviceMesh device_mesh = input_dtensor_sinfos[0]->device_mesh; Var output_var("output", orig_output_sinfo); @@ -72,7 +72,7 @@ StructInfo InferShardingSpec(const Call& call, const BlockBuilder& ctx, orig_output_tensor_sinfos.push_back(ffi::GetRef(tensor_sinfo)); } else { const auto* tuple_sinfo = orig_output_sinfo.as(); - ICHECK(tuple_sinfo); + TVM_FFI_ICHECK(tuple_sinfo); for (const auto& sinfo : tuple_sinfo->fields) { orig_output_tensor_sinfos.push_back(Downcast(sinfo)); } diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index 59d845d867f6..15fdcf3eb99a 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -114,7 +114,7 @@ InferLayoutOutput InferLayoutResize2d( const VarLayoutMap& var_layout_map) { const auto& it = desired_layouts.find("relax.image.resize2d"); const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision data_layout; ObjectPtr new_attrs = ffi::make_object(*attrs); @@ -122,7 +122,8 @@ InferLayoutOutput InferLayoutResize2d( if (it != desired_layouts.end()) { // We have a desired layout for resize2d. Layout desired_data_layout = (*it).second[0]; - ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only"; + TVM_FFI_ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) + << "Axis swap only"; data_layout = TransposeLike(InitialLayout(4), attrs->layout, desired_data_layout); new_attrs->layout = (*it).second[0]; } else { diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index 04a845bd816d..c6b08c8aef98 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -64,9 +64,9 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { if (auto opt = sinfo.as()) { return opt.value(); } else { - LOG(FATAL) << "TypeError: " - << "Operator " << call->op << " expects first argument to be a tensor, " - << "but received " << arg_data << " with type " << sinfo; + TVM_FFI_THROW(TypeError) << "Operator " << call->op + << " expects first argument to be a tensor, " + << "but received " << arg_data << " with type " << sinfo; } }(); auto view_shape_sinfo = [&]() -> const ShapeStructInfoNode* { @@ -79,10 +79,10 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { // The `R.view` operation returns a different shape. return ptr; } else { - LOG(FATAL) << "TypeError: " - << "Operator " << call->op << " expects second argument to be a ShapeExpr, " - << "or a void-type (empty relax tuple), " - << "but received " << arg_shape << " with type " << sinfo; + TVM_FFI_THROW(TypeError) << "Operator " << call->op + << " expects second argument to be a ShapeExpr, " + << "or a void-type (empty relax tuple), " + << "but received " << arg_shape << " with type " << sinfo; } }(); @@ -117,10 +117,9 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { // being changed into. return DataType::Void(); } else { - LOG(FATAL) << "TypeError: " - << "Operator " << call->op - << " expects the dtype argument to be a relax::DataTypeImm, " - << "but received " << arg_dtype << " with type " << sinfo; + TVM_FFI_THROW(TypeError) << "Operator " << call->op + << " expects the dtype argument to be a relax::DataTypeImm, " + << "but received " << arg_dtype << " with type " << sinfo; } }(); @@ -131,8 +130,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { // No byte offset is specified, so no change is applied. return IntImm(DataType::Int(64), 0); } else if (auto prim_sinfo = sinfo.as()) { - CHECK_EQ(prim_sinfo->dtype, DataType::Int(64)) - << "TypeError: " + TVM_FFI_CHECK_EQ(prim_sinfo->dtype, DataType::Int(64), TypeError) << "Operator " << call->op << " expects the relative_byte_offset to be a 64-bit integer, but received " << arg_relative_byte_offset << ", which has type " << sinfo; @@ -145,11 +143,12 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { return std::nullopt; } } else { - LOG(FATAL) << "TypeError: " - << "Operator " << call->op << " expects the relative_byte_offset argument " - << "to be a Relax PrimValue. " - << "However, expression " << call << " provides relative_byte_offset of " - << arg_relative_byte_offset << ", which has type " << sinfo; + TVM_FFI_THROW(TypeError) << "Operator " << call->op + << " expects the relative_byte_offset argument " + << "to be a Relax PrimValue. " + << "However, expression " << call + << " provides relative_byte_offset of " << arg_relative_byte_offset + << ", which has type " << sinfo; } }(); @@ -214,15 +213,15 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { if (ctx->GetAnalyzer()->CanProve(output_nbytes + view_relative_byte_offset.value() > input_nbytes)) { - LOG(FATAL) << "ValueError: " - << "Views into an array must not exceed the bounds of the array being viewed. " - << "However, expression " << call << " attempted to create view of type " - << TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype) - << " with relative byte offset " << view_relative_byte_offset - << ", viewing into the array " << arg_data << " of type " << data_sinfo << ". " - << "The end of the view would occur at byte " << view_end - << ", relative to the start of array " << arg_data << ", but " << arg_data - << " is only " << input_nbytes << " long."; + TVM_FFI_THROW(ValueError) + << "Views into an array must not exceed the bounds of the array being viewed. " + << "However, expression " << call << " attempted to create view of type " + << TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype) + << " with relative byte offset " << view_relative_byte_offset + << ", viewing into the array " << arg_data << " of type " << data_sinfo << ". " + << "The end of the view would occur at byte " << view_end + << ", relative to the start of array " << arg_data << ", but " << arg_data << " is only " + << input_nbytes << " long."; } } else if (input_nelements && output_nelements && input_element_size && output_element_size) { @@ -236,13 +235,13 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { PrimExpr output_nbytes = output_nelements.value() * output_element_size.value(); if (ctx->GetAnalyzer()->CanProve(output_nbytes > input_nbytes)) { - LOG(FATAL) << "ValueError: " - << "Views into an array must not exceed the bounds of the array being viewed. " - << "However, expression " << call << " attempted to create view of type " - << TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype) - << " from input array of type " << data_sinfo << ". " - << "This view would increase the size from " << output_nbytes << " bytes to " - << output_nbytes << " bytes."; + TVM_FFI_THROW(ValueError) + << "Views into an array must not exceed the bounds of the array being viewed. " + << "However, expression " << call << " attempted to create view of type " + << TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype) + << " from input array of type " << data_sinfo << ". " + << "This view would increase the size from " << output_nbytes << " bytes to " + << output_nbytes << " bytes."; } } else if (input_element_size && output_element_size && !view_shape_sinfo) { @@ -252,8 +251,8 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { // elements, an increase to the per-element size would cause the // view to be larger than the original array. - CHECK_GE(input_element_size.value()->value, output_element_size.value()->value) - << "ValueError: " + TVM_FFI_CHECK_GE(input_element_size.value()->value, output_element_size.value()->value, + ValueError) << "Operator " << call->op << " may not produce a view that exceeds the bounds of the original array. " << "In expression " << call << " the data type is changed from " << data_sinfo->dtype @@ -269,23 +268,23 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { // individual element. if (ctx->GetAnalyzer()->CanProve(output_nelements.value() > input_nelements.value())) { - LOG(FATAL) << "ValueError: " - << "Views into an array must not exceed the bounds of the array being viewed. " - << "However, expression " << call << " attempted to view array " << arg_data - << " (shape = " << input_shape << ", " << input_nelements << " elements) as shape " - << output_shape << " with " << output_nelements << " elements."; + TVM_FFI_THROW(ValueError) + << "Views into an array must not exceed the bounds of the array being viewed. " + << "However, expression " << call << " attempted to view array " << arg_data + << " (shape = " << input_shape << ", " << input_nelements << " elements) as shape " + << output_shape << " with " << output_nelements << " elements."; } } else if (view_relative_byte_offset && !view_shape_sinfo && !view_dtype) { // The byte_offset is being updated, but neither the shape nor the // dtype is changing. Any non-zero offset will cause the view to // overrun the bounds of the original array. if (ctx->GetAnalyzer()->CanProve(view_relative_byte_offset.value() > 0)) { - LOG(FATAL) << "ValueError: " - << "Views into an array must not exceed the bounds of the array being viewed. " - << "However, expression " << call << " attempted to offset the view by " - << view_relative_byte_offset << " bytes, " - << "without reducing either the number of elements in the view " - << "or the size of each element."; + TVM_FFI_THROW(ValueError) + << "Views into an array must not exceed the bounds of the array being viewed. " + << "However, expression " << call << " attempted to offset the view by " + << view_relative_byte_offset << " bytes, " + << "without reducing either the number of elements in the view " + << "or the size of each element."; } } @@ -320,7 +319,7 @@ Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { if (HasVoidStructInfo(shape)) { auto data_shape = data->struct_info_.as().value()->GetShape(); - CHECK(data_shape.defined()) + TVM_FFI_ICHECK(data_shape.defined()) << "Legalization of " << call->op << " requires that either the output shape be explicitly specified, " << "or the input shape is known. " @@ -331,7 +330,7 @@ Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { if (HasVoidStructInfo(dtype)) { auto data_dtype = data->struct_info_.as().value()->dtype; - CHECK(!data_dtype.is_void()) + TVM_FFI_ICHECK(!data_dtype.is_void()) << "Legalization of " << call->op << " requires that either the output dtype be explicitly specified, " << "or the input dtype is known. " diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 5368db79d262..648fc50c89c2 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -47,12 +47,13 @@ Expr conv1d(Expr data, Expr weight, ffi::Array strides, ffi::Array out_dtype) { padding = GetCompletePadding1D(std::move(padding)); - CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, " - "the given number of groups is " - << groups; - CHECK_EQ(strides.size(), 1) + TVM_FFI_ICHECK_GT(groups, 0) + << "The number of groups in convolution is expected to be positive. However, " + "the given number of groups is " + << groups; + TVM_FFI_ICHECK_EQ(strides.size(), 1) << "The input strides length is expected to be 1. However, the given strides is " << strides; - CHECK_EQ(dilation.size(), 1) + TVM_FFI_ICHECK_EQ(dilation.size(), 1) << "The input dilation length is expected to be 1. However, the given dilation is " << dilation; return MakeConv(std::move(data), std::move(weight), std::move(strides), @@ -144,7 +145,7 @@ InferLayoutOutput InferLayoutConv1d( const VarLayoutMap& var_layout_map) { const auto& it = desired_layouts.find("relax.nn.conv1d"); const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision data_layout, weight_layout, output_layout; ObjectPtr new_attrs = ffi::make_object(*attrs); @@ -154,10 +155,11 @@ InferLayoutOutput InferLayoutConv1d( Layout desired_data_layout = (*it).second[0]; Layout desired_weight_layout = (*it).second[1]; Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; - ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only"; - ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal()) + TVM_FFI_ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only"; - ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal()) + TVM_FFI_ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal()) + << "Axis swap only"; + TVM_FFI_ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal()) << "Axis swap only"; data_layout = TransposeLike(InitialLayout(3), attrs->data_layout, desired_data_layout); weight_layout = TransposeLike(InitialLayout(3), attrs->kernel_layout, desired_weight_layout); @@ -214,12 +216,13 @@ Expr conv2d(Expr data, Expr weight, ffi::Array strides, ffi::Array(std::move(data), std::move(weight), std::move(strides), @@ -316,7 +319,7 @@ InferLayoutOutput InferLayoutConv2d( const VarLayoutMap& var_layout_map) { const auto& it = desired_layouts.find("relax.nn.conv2d"); const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision data_layout, weight_layout, output_layout; data_layout = GetLayoutDecision(var_layout_map, call->args[0]); @@ -423,12 +426,13 @@ Expr conv3d(Expr data, Expr weight, ffi::Array strides, ffi::Array(std::move(data), std::move(weight), std::move(strides), @@ -531,7 +535,7 @@ InferLayoutOutput InferLayoutConv3d( const VarLayoutMap& var_layout_map) { const auto& it = desired_layouts.find("relax.nn.conv3d"); const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision data_layout, weight_layout, output_layout; ObjectPtr new_attrs = ffi::make_object(*attrs); @@ -541,10 +545,11 @@ InferLayoutOutput InferLayoutConv3d( Layout desired_data_layout = (*it).second[0]; Layout desired_weight_layout = (*it).second[1]; Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; - ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only"; - ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal()) + TVM_FFI_ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) + << "Axis swap only"; + TVM_FFI_ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal()) << "Axis swap only"; - ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal()) + TVM_FFI_ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal()) << "Axis swap only"; data_layout = TransposeLike(InitialLayout(5), attrs->data_layout, desired_data_layout); weight_layout = TransposeLike(InitialLayout(5), attrs->kernel_layout, desired_weight_layout); @@ -594,15 +599,17 @@ Expr conv1d_transpose(Expr data, Expr weight, ffi::Array strides, ffi::Optional out_dtype) { padding = GetCompletePadding1D(std::move(padding)); - CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, " - "the given number of groups is " - << groups; - CHECK_EQ(output_padding.size(), 1) << "The input output_padding length is expected to be 1. " - "However, the given output_padding is " - << output_padding; - CHECK_EQ(strides.size(), 1) + TVM_FFI_ICHECK_GT(groups, 0) + << "The number of groups in convolution is expected to be positive. However, " + "the given number of groups is " + << groups; + TVM_FFI_ICHECK_EQ(output_padding.size(), 1) + << "The input output_padding length is expected to be 1. " + "However, the given output_padding is " + << output_padding; + TVM_FFI_ICHECK_EQ(strides.size(), 1) << "The input strides length is expected to be 1. However, the given strides is " << strides; - CHECK_EQ(dilation.size(), 1) + TVM_FFI_ICHECK_EQ(dilation.size(), 1) << "The input dilation length is expected to be 1. However, the given dilation is " << dilation; @@ -720,10 +727,11 @@ InferLayoutOutput InferLayoutConv1dTranspose( Layout desired_data_layout = (*it).second[0]; Layout desired_weight_layout = (*it).second[1]; Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; - ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only"; - ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal()) + TVM_FFI_ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) + << "Axis swap only"; + TVM_FFI_ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal()) << "Axis swap only"; - ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal()) + TVM_FFI_ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal()) << "Axis swap only"; data_layout = TransposeLike(InitialLayout(3), attrs->data_layout, desired_data_layout); weight_layout = TransposeLike(InitialLayout(3), attrs->kernel_layout, desired_weight_layout); @@ -784,15 +792,17 @@ Expr conv2d_transpose(Expr data, Expr weight, ffi::Array strides, dilation.push_back(dilation[0]); } - CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, " - "the given number of groups is " - << groups; - CHECK_EQ(output_padding.size(), 2) << "The input output_padding length is expected to be 2. " - "However, the given output_padding is " - << output_padding; - CHECK_EQ(strides.size(), 2) + TVM_FFI_ICHECK_GT(groups, 0) + << "The number of groups in convolution is expected to be positive. However, " + "the given number of groups is " + << groups; + TVM_FFI_ICHECK_EQ(output_padding.size(), 2) + << "The input output_padding length is expected to be 2. " + "However, the given output_padding is " + << output_padding; + TVM_FFI_ICHECK_EQ(strides.size(), 2) << "The input strides length is expected to be 2. However, the given strides is " << strides; - CHECK_EQ(dilation.size(), 2) + TVM_FFI_ICHECK_EQ(dilation.size(), 2) << "The input dilation length is expected to be 2. However, the given dilation is " << dilation; diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 0a2335834399..68051a31f7e9 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -136,17 +136,17 @@ StructInfo InferStructInfoPRelu(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutPRelu( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); // TODO(Siva): We could handle if the axis is not the sub indexed one. if (layout->layout.ndim() != layout->layout.ndim_primal()) { const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; int ndim = tensor_sinfo->ndim; layout = LayoutDecision(InitialLayout(ndim)); } @@ -201,17 +201,17 @@ StructInfo InferStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutSoftmax( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); // TODO(Siva): We could handle if the axis is not the sub indexed one. if (layout->layout.ndim() != layout->layout.ndim_primal()) { const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; int ndim = tensor_sinfo->ndim; layout = LayoutDecision(InitialLayout(ndim)); } @@ -270,7 +270,7 @@ StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); int ndim = input_sinfo[0]->ndim; ffi::Array pad_width = attrs->pad_width; - ICHECK(static_cast(pad_width.size()) == 2 * ndim) << "Illegal pad_width"; + TVM_FFI_ICHECK(static_cast(pad_width.size()) == 2 * ndim) << "Illegal pad_width"; ffi::Array out_shape; if (input_sinfo[0]->shape.defined()) { @@ -314,11 +314,11 @@ StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); int r = attrs->upscale_factor; - ICHECK_GT(r, 0) << "Upscale factor must be positive"; + TVM_FFI_ICHECK_GT(r, 0) << "Upscale factor must be positive"; const TensorStructInfo& input = input_sinfo[0]; int ndim = input->ndim; - ICHECK_GE(ndim, 3) << "PixelShuffle requires at least 3D input tensor"; + TVM_FFI_ICHECK_GE(ndim, 3) << "PixelShuffle requires at least 3D input tensor"; if (!input->shape.defined()) { return TensorStructInfo(input->dtype, ndim); @@ -341,7 +341,7 @@ StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx const auto* c_in_imm = c_in.as(); const auto* r2_imm = r_squared.as(); - ICHECK_EQ(c_in_imm->value % r2_imm->value, 0) + TVM_FFI_ICHECK_EQ(c_in_imm->value % r2_imm->value, 0) << "Number of input channels must be divisible by the square of the upscale factor"; // Output shape: @@ -482,16 +482,16 @@ StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutBatchNorm( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 5; ++i) { const auto* tensor_sinfo = GetStructInfoAs(call->args[i]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim)); } const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); // While dealing with sub layouts, its adviced to deal with batchnorm @@ -555,16 +555,16 @@ StructInfo InferStructInfoLayerNorm(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutLayerNorm( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 3; ++i) { const auto* tensor_sinfo = GetStructInfoAs(call->args[i]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim)); } const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ObjectPtr new_attrs = ffi::make_object(*attrs); @@ -670,16 +670,16 @@ StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutGroupNorm( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 3; ++i) { const auto* tensor_sinfo = GetStructInfoAs(call->args[i]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim)); } const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ObjectPtr new_attrs = ffi::make_object(*attrs); @@ -728,7 +728,7 @@ StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx Op op = Downcast(call->op); ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; TensorStructInfo data_sinfo = input_sinfo[0]; int channel_axis = -1; @@ -773,16 +773,16 @@ StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx InferLayoutOutput InferLayoutInstanceNorm( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 3; ++i) { const auto* tensor_sinfo = GetStructInfoAs(call->args[i]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim)); } const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ObjectPtr new_attrs = ffi::make_object(*attrs); @@ -836,16 +836,16 @@ StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutRMSNorm( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); std::vector initial_layouts; for (size_t i = 0; i < 2; ++i) { const auto* tensor_sinfo = GetStructInfoAs(call->args[i]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim)); } const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ObjectPtr new_attrs = ffi::make_object(*attrs); @@ -966,7 +966,7 @@ Expr nll_loss(Expr predictions, Expr targets, ffi::Optional weights, ffi:: int ignore_index) { ObjectPtr attrs = ffi::make_object(); - ICHECK(reduction == "none" || reduction == "sum" || reduction == "mean") + TVM_FFI_ICHECK(reduction == "none" || reduction == "sum" || reduction == "mean") << "The argument reduction of NLLLoss should be one of the following " "values: none, mean, sum. However, the given value is " << reduction; @@ -1078,12 +1078,12 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { if (pred_shape_value.defined()) { if (pred_shape_value.value().size() == 1) { // (C,) - ICHECK(pred_sinfo->ndim == 1); + TVM_FFI_ICHECK(pred_sinfo->ndim == 1); C = pred_shape_value.value()[0]; } else { // (N, C, d1, d2, ..., dk) - ICHECK(pred_shape_value.value().size() >= 2); - ICHECK(pred_sinfo->ndim == static_cast(pred_shape_value.value().size())); + TVM_FFI_ICHECK(pred_shape_value.value().size() >= 2); + TVM_FFI_ICHECK(pred_sinfo->ndim == static_cast(pred_shape_value.value().size())); N = pred_shape_value.value()[0]; C = pred_shape_value.value()[1]; output_shape = ffi::Array(); @@ -1101,7 +1101,7 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { if (tgt_shape_value.defined()) { if (tgt_shape_value.value().empty()) { // () - ICHECK(tgt_sinfo->ndim == 0); + TVM_FFI_ICHECK(tgt_sinfo->ndim == 0); if (N.defined()) { ctx->ReportFatal(Diagnostic::Error(call) << "Shape mismatch for NLLLoss. Predictions shape is " @@ -1126,12 +1126,12 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { if (tgt_shape_value.value().size() == 1) { // (N,) - ICHECK(tgt_sinfo->IsUnknownNdim() || tgt_sinfo->ndim == 1); + TVM_FFI_ICHECK(tgt_sinfo->IsUnknownNdim() || tgt_sinfo->ndim == 1); } else { // (N, d1, d2, ..., dk) - ICHECK(tgt_shape_value.value().size() >= 2); - ICHECK(tgt_sinfo->IsUnknownNdim() || - tgt_sinfo->ndim == static_cast(tgt_shape_value.value().size())); + TVM_FFI_ICHECK(tgt_shape_value.value().size() >= 2); + TVM_FFI_ICHECK(tgt_sinfo->IsUnknownNdim() || + tgt_sinfo->ndim == static_cast(tgt_shape_value.value().size())); if (pred_shape_value.defined()) { // check (d1, d2, ..., dk) @@ -1154,8 +1154,8 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { wgt_shape_value = GetStructInfoAs(wgt_sinfo->shape.value())->values; } if (wgt_shape_value.defined()) { - ICHECK(wgt_shape_value.value().size() == 1); - ICHECK(wgt_sinfo->IsUnknownNdim() || wgt_sinfo->ndim == 1); + TVM_FFI_ICHECK(wgt_shape_value.value().size() == 1); + TVM_FFI_ICHECK(wgt_sinfo->IsUnknownNdim() || wgt_sinfo->ndim == 1); const PrimExpr& C_wgt = wgt_shape_value.value()[0]; if (C.defined() && analyzer->CanProve(C.value() != C_wgt)) { ctx->ReportFatal(Diagnostic::Error(call) diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 2397bf009866..3e963cfd145f 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -44,12 +44,12 @@ Expr MakePool1d(ffi::String op_name, Expr data, ffi::Array pool_size, ffi::String layout, ffi::Optional out_layout) { padding = GetCompletePadding1D(std::move(padding)); - CHECK_EQ(pool_size.size(), 1) + TVM_FFI_ICHECK_EQ(pool_size.size(), 1) << "The input pool_size length is expected to be 1. However, the given pool_size is " << pool_size; - CHECK_EQ(strides.size(), 1) + TVM_FFI_ICHECK_EQ(strides.size(), 1) << "The input strides length is expected to be 1. However, the given strides is " << strides; - CHECK_EQ(dilation.size(), 1) + TVM_FFI_ICHECK_EQ(dilation.size(), 1) << "The input dilation length is expected to be 1. However, the given dilation is " << dilation; @@ -127,12 +127,12 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutPool1d( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK_EQ(tensor_sinfo->ndim, 3) << "Unsupported initial layout"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK_EQ(tensor_sinfo->ndim, 3) << "Unsupported initial layout"; const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ObjectPtr new_attrs = ffi::make_object(*attrs); @@ -167,12 +167,12 @@ Expr MakePool2d(ffi::String op_name, Expr data, ffi::Array pool_size, dilation.push_back(dilation[0]); } - CHECK_EQ(pool_size.size(), 2) + TVM_FFI_ICHECK_EQ(pool_size.size(), 2) << "The input pool_size length is expected to be 2. However, the given pool_size is " << pool_size; - CHECK_EQ(strides.size(), 2) + TVM_FFI_ICHECK_EQ(strides.size(), 2) << "The input strides length is expected to be 2. However, the given strides is " << strides; - CHECK_EQ(dilation.size(), 2) + TVM_FFI_ICHECK_EQ(dilation.size(), 2) << "The input dilation length is expected to be 2. However, the given dilation is " << dilation; @@ -260,12 +260,12 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutPool2d( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK_EQ(tensor_sinfo->ndim, 4) << "Unsupported initial layout"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK_EQ(tensor_sinfo->ndim, 4) << "Unsupported initial layout"; const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ObjectPtr new_attrs = ffi::make_object(*attrs); @@ -321,12 +321,12 @@ Expr MakePool3d(ffi::String op_name, Expr data, ffi::Array pool_size, dilation.push_back(dilation[0]); } - CHECK_EQ(pool_size.size(), 3) + TVM_FFI_ICHECK_EQ(pool_size.size(), 3) << "The input pool_size length is expected to be 3. However, the given pool_size is " << pool_size; - CHECK_EQ(strides.size(), 3) + TVM_FFI_ICHECK_EQ(strides.size(), 3) << "The input strides length is expected to be 3. However, the given strides is " << strides; - CHECK_EQ(dilation.size(), 3) + TVM_FFI_ICHECK_EQ(dilation.size(), 3) << "The input dilation length is expected to be 3. However, the given dilation is " << dilation; @@ -424,12 +424,12 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutPool3d( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK_EQ(tensor_sinfo->ndim, 5) << "Unsupported initial layout"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK_EQ(tensor_sinfo->ndim, 5) << "Unsupported initial layout"; const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ObjectPtr new_attrs = ffi::make_object(*attrs); @@ -522,7 +522,7 @@ Expr adaptive_avg_pool1d(Expr data, ffi::Optional> output_si attrs->out_layout = out_layout.value_or(layout); if (output_size.defined()) { ffi::Array _output_size = output_size.value(); - CHECK_EQ(_output_size.size(), 1) + TVM_FFI_ICHECK_EQ(_output_size.size(), 1) << "The output_size length is expected to be 1. However, the given output_size is " << _output_size; attrs->output_size = std::move(_output_size); @@ -572,12 +572,12 @@ StructInfo InferStructInfoAdaptiveAvgPool1D(const Call& call, const BlockBuilder InferLayoutOutput InferLayoutAdaptiveAvgPool1D( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK_EQ(tensor_sinfo->ndim, 3) << "Unsupported initial layout"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK_EQ(tensor_sinfo->ndim, 3) << "Unsupported initial layout"; const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ObjectPtr new_attrs = ffi::make_object(*attrs); @@ -607,7 +607,7 @@ Expr adaptive_avg_pool2d(Expr data, ffi::Optional> output_si if (_output_size.size() == 1) { _output_size.push_back(_output_size[0]); } - CHECK_EQ(_output_size.size(), 2) + TVM_FFI_ICHECK_EQ(_output_size.size(), 2) << "The output_size length is expected to be 2. However, the given output_size is " << _output_size; attrs->output_size = std::move(_output_size); @@ -658,12 +658,12 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder InferLayoutOutput InferLayoutAdaptiveAvgPool2D( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK_EQ(tensor_sinfo->ndim, 4) << "Unsupported initial layout"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK_EQ(tensor_sinfo->ndim, 4) << "Unsupported initial layout"; const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ObjectPtr new_attrs = ffi::make_object(*attrs); @@ -709,7 +709,7 @@ Expr adaptive_avg_pool3d(Expr data, ffi::Optional> output_si if (_output_size.size() == 1) { _output_size.push_back(_output_size[0]); } - CHECK_EQ(_output_size.size(), 3) + TVM_FFI_ICHECK_EQ(_output_size.size(), 3) << "The output_size length is expected to be 3. However, the given output_size is " << _output_size; attrs->output_size = std::move(_output_size); @@ -761,12 +761,12 @@ StructInfo InferStructInfoAdaptiveAvgPool3D(const Call& call, const BlockBuilder InferLayoutOutput InferLayoutAdaptiveAvgPool3D( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK_EQ(tensor_sinfo->ndim, 5) << "Unsupported initial layout"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK_EQ(tensor_sinfo->ndim, 5) << "Unsupported initial layout"; const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ObjectPtr new_attrs = ffi::make_object(*attrs); diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index d7d68766dd1a..e78b75fe14ed 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -69,8 +69,8 @@ StructInfo InferStructInfoShapeOf(const Call& call, const BlockBuilder& ctx) { // use the StructInfo of the argument auto arg_sinfo = GetStructInfo(call->args[0]); auto* tensor_sinfo = GetStructInfo(call->args[0]).as(); - CHECK(tensor_sinfo) << "shape_of expects a tensor input, but received " << arg_sinfo - << "; use MatchCast if necessary"; + TVM_FFI_ICHECK(tensor_sinfo) << "shape_of expects a tensor input, but received " << arg_sinfo + << "; use MatchCast if necessary"; if (tensor_sinfo->ndim == kUnknownNDim) { return ShapeStructInfo(kUnknownNDim); } @@ -80,7 +80,7 @@ StructInfo InferStructInfoShapeOf(const Call& call, const BlockBuilder& ctx) { } // otherwise, copy over the values from the tensor shape auto* tensor_shape = tensor_sinfo->shape.as(); - CHECK(tensor_shape); + TVM_FFI_ICHECK(tensor_shape); return ShapeStructInfo(tensor_shape->values); } @@ -94,12 +94,13 @@ StructInfo InferStructInfoCallPurePacked(const Call& call, const BlockBuilder& c // the callee must be an opaque function auto callee = call->args[0]; - ICHECK(!callee.as()) << "call_pure_packed cannot be used with an op node"; + TVM_FFI_ICHECK(!callee.as()) << "call_pure_packed cannot be used with an op node"; auto opt = MatchStructInfo(callee); - ICHECK(opt) << "Callee must have a function struct info"; + TVM_FFI_ICHECK(opt) << "Callee must have a function struct info"; FuncStructInfo finfo = opt.value(); - ICHECK(finfo->IsOpaque()) << "call_pure_packed must be called with an opaque function, but " - << callee << " is not opaque"; + TVM_FFI_ICHECK(finfo->IsOpaque()) + << "call_pure_packed must be called with an opaque function, but " << callee + << " is not opaque"; // same logic as from DeriveCallRetStructInfo for ordinary calls if (finfo->derive_func.defined()) { @@ -147,12 +148,13 @@ StructInfo InferStructInfoCallInplacePacked(const Call& call, const BlockBuilder // the callee must be an opaque function auto callee = call->args[0]; - ICHECK(!callee.as()) << "call_pure_packed cannot be used with an op node"; + TVM_FFI_ICHECK(!callee.as()) << "call_pure_packed cannot be used with an op node"; auto opt = MatchStructInfo(callee); - ICHECK(opt) << "Callee must have a function struct info"; + TVM_FFI_ICHECK(opt) << "Callee must have a function struct info"; FuncStructInfo finfo = opt.value(); - ICHECK(finfo->IsOpaque()) << "call_pure_packed must be called with an opaque function, but " - << callee << " is not opaque"; + TVM_FFI_ICHECK(finfo->IsOpaque()) + << "call_pure_packed must be called with an opaque function, but " << callee + << " is not opaque"; // check the range for inplace indices, make sure at least one is not -1, ensure they're unique const auto* attrs = call->attrs.as(); @@ -290,22 +292,20 @@ static ffi::Optional InferCallTIROutputStructInfoFromArguments( StructInfo func_sinfo, StructInfo arg_sinfo, ffi::Optional packed_ints_sinfo, ffi::Optional> opt_inplace_indices) { auto opt_callee_sinfo = func_sinfo.as(); - CHECK(opt_callee_sinfo) << "TypeError: " - << "The first argument to `R.call_tir` must be a function, " - << "but instead received argument of type " << func_sinfo; + TVM_FFI_CHECK(opt_callee_sinfo, TypeError) + << "The first argument to `R.call_tir` must be a function, " + << "but instead received argument of type " << func_sinfo; auto callee_sinfo = opt_callee_sinfo.value(); - CHECK(callee_sinfo->params.defined()) - << "ValueError: " + TVM_FFI_CHECK(callee_sinfo->params.defined(), ValueError) << "The first argument to `R.call_tir` must be a function " << "with known argument types. " << "However, the first argument was of type " << callee_sinfo; auto callee_params = callee_sinfo->params.value(); const TupleStructInfoNode* args = arg_sinfo.as(); - CHECK(args) << "TypeError: " - << "The second argument to `R.call_tir` must be a tuple, " - << "but instead received expression of type " << arg_sinfo; + TVM_FFI_CHECK(args, TypeError) << "The second argument to `R.call_tir` must be a tuple, " + << "but instead received expression of type " << arg_sinfo; // R.call_tir expects the PrimFunc to have three groups of arguments. // @@ -322,8 +322,7 @@ static ffi::Optional InferCallTIROutputStructInfoFromArguments( if (packed_ints_sinfo) { auto packed_sinfo = packed_ints_sinfo.value(); packed_tuple_sinfo = packed_sinfo.as(); - CHECK(packed_tuple_sinfo && !packed_tuple_sinfo->IsUnknownNdim()) - << "TypeError: " + TVM_FFI_CHECK(packed_tuple_sinfo && !packed_tuple_sinfo->IsUnknownNdim(), TypeError) << "The third argument to `R.call_tir`, if present, " << "must be a ffi::Shape with known dimensionality. " << "However, the argument received was of type " << packed_sinfo; @@ -332,8 +331,8 @@ static ffi::Optional InferCallTIROutputStructInfoFromArguments( num_trailing_int_arguments = 0; } - CHECK_LE(num_input_arguments + num_trailing_int_arguments, callee_params.size()) - << "ValueError: " + TVM_FFI_CHECK_LE(num_input_arguments + num_trailing_int_arguments, callee_params.size(), + ValueError) << "R.call_tir attempted to call a function using " << num_input_arguments << " input arguments and " << num_trailing_int_arguments << " trailing integer arguments. " << "However, the callee only accepts " << callee_params.size() << " arguments in total."; @@ -411,7 +410,7 @@ static ffi::Optional InferCallTIROutputStructInfoFromArguments( [](const StructInfo& sinfo) -> Expr { return Var("dummy_leading_arg", sinfo); }); for (size_t i = 0; i < num_trailing_int_arguments; i++) { - ICHECK(packed_tuple_sinfo); + TVM_FFI_ICHECK(packed_tuple_sinfo); PrimStructInfo dummy_arg_sinfo = [&]() { if (packed_tuple_sinfo->values) { return PrimStructInfo(packed_tuple_sinfo->values.value()[i]); @@ -437,7 +436,7 @@ StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { ctx->ReportFatal(Diagnostic::Error(call) << "sinfo_args should have exactly 1 output struct info."); } - CHECK(call->args[0]->IsInstance()) + TVM_FFI_ICHECK(call->args[0]->IsInstance()) << "R.call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. " << "However, the argument " << call->args[0] << " instead has type " << call->args[0]->GetTypeKey(); @@ -453,24 +452,24 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { // `relax.call_tir_inplace`. Therefore, all error messages should // be written in terms of `call->op`, and should not explicitly // reference the `relax.call_tir` operator.` - CHECK(call->args.size() == 2 || call->args.size() == 3) + TVM_FFI_ICHECK(call->args.size() == 2 || call->args.size() == 3) << "Operation " << call->op << " expects either two arguments [callee, arg_tuple], " << "or three arguments [callee, arg_tuple, tir_args], " << "but " << call << " has " << call->args.size() << " arguments."; auto callee = call->args[0]; - CHECK(callee->struct_info_.as()) + TVM_FFI_ICHECK(callee->struct_info_.as()) << "Operation " << call->op << " expects the first argument to be a TIR callee. " << "However, the first argument " << callee << " has struct info " << callee->struct_info_; Expr arg_tuple = call->args[1]; - CHECK(arg_tuple->struct_info_.as()) + TVM_FFI_ICHECK(arg_tuple->struct_info_.as()) << "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. " << "However, the second argument " << arg_tuple << " has struct info " << arg_tuple->struct_info_ << "."; - CHECK(arg_tuple.as() || arg_tuple.as()) + TVM_FFI_ICHECK(arg_tuple.as() || arg_tuple.as()) << "Operation " << call->op << " must hold its arguments as an in-line tuple. " << "However, " << call << " has arguments " << arg_tuple << ", which is neither an in-line tuple, " @@ -478,14 +477,14 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { if (call->args.size() > 2) { Expr packed_ints = call->args[2]; - CHECK(packed_ints->struct_info_.as()) + TVM_FFI_ICHECK(packed_ints->struct_info_.as()) << "Operation " << call->op << " expects the optional third argument, " << "if present, to be a ffi::Shape. " << "However, the third argument " << packed_ints << " has struct info " << packed_ints->struct_info_; } - CHECK_EQ(call->sinfo_args.size(), 1) + TVM_FFI_ICHECK_EQ(call->sinfo_args.size(), 1) << "R.call_tir should have exactly one `sinfo_args` parameter, " << "which defines the output of the PrimFunc."; @@ -567,8 +566,7 @@ void ValidateCallTIR(Call call) { auto inferred_sinfo = InferCallTIROutputStructInfoFromArguments( GetStructInfo(callee), GetStructInfo(arg_tuple), packed_int_sinfo, opt_inplace_indices); if (inferred_sinfo.defined()) { - CHECK(IsBaseOf(inferred_sinfo.value(), explicit_sinfo)) - << "TypeError: " + TVM_FFI_CHECK(IsBaseOf(inferred_sinfo.value(), explicit_sinfo), TypeError) << "The `out_sinfo` argument for R.call_tir must be compatible with the PrimFunc. " << "However, the PrimFunc's signature implies that the output should be " << inferred_sinfo << ", but the `out_sinfo` argument was " << explicit_sinfo; @@ -591,9 +589,10 @@ Expr MakeCallTIR(Expr func, Tuple args, ffi::Array out_sinfo_l ffi::Optional packed_ints) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); - CHECK(shape != nullptr) << "out_sinfo of call_tir should have defined ShapeExpr as shape. " - "However, one given structure info is " - << sinfo; + TVM_FFI_ICHECK(shape != nullptr) + << "out_sinfo of call_tir should have defined ShapeExpr as shape. " + "However, one given structure info is " + << sinfo; } StructInfo out_sinfo{nullptr}; @@ -639,7 +638,7 @@ Expr MakeCallTIRWithGrad(Expr func, Tuple args, ffi::Array out ffi::Optional packed_ints) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); - CHECK(shape != nullptr) + TVM_FFI_ICHECK(shape != nullptr) << "out_sinfo of call_tir_with_grad should have defined ShapeExpr as shape. " "However, one given structure info is " << sinfo; @@ -691,7 +690,7 @@ Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) { // there must be an inplace index for each output const auto* attrs = call->attrs.as(); - ICHECK(attrs); + TVM_FFI_ICHECK(attrs); if (attrs->inplace_indices.size() != sinfo_outputs.size()) { ctx->ReportFatal(Diagnostic::Error(call) << "There must be an in-place index specified for each output"); @@ -784,9 +783,10 @@ Expr MakeCallTIRInplace(Expr func, Tuple args, ffi::Array inplace_indic ffi::Optional packed_ints) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); - CHECK(shape != nullptr) << "out_sinfo of call_tir should have defined ShapeExpr as shape. " - "However, one given structure info is " - << sinfo; + TVM_FFI_ICHECK(shape != nullptr) + << "out_sinfo of call_tir should have defined ShapeExpr as shape. " + "However, one given structure info is " + << sinfo; } ObjectPtr attrs = ffi::make_object(); @@ -837,7 +837,7 @@ TVM_REGISTER_OP("relax.call_dps_packed") Expr MakeCallDPSPacked(Expr func, Tuple args, ffi::Array out_sinfo_list) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); - CHECK(shape != nullptr) + TVM_FFI_ICHECK(shape != nullptr) << "out_sinfo of call_dps_packed should have defined ShapeExpr as shape. " "However, one given structure info is " << sinfo; @@ -872,19 +872,19 @@ StructInfo InferStructInfoCallPyFunc(const Call& call, const BlockBuilder& ctx) void ValidateCallPyFunc(Call call) { // Validate that the function name is a string literal auto func_name = call->args[0]; - CHECK(func_name->IsInstance()) + TVM_FFI_ICHECK(func_name->IsInstance()) << "Operation " << call->op << " expects the first argument to be a string literal " << "specifying the Python function name. However, the first argument " << func_name << " is not a string literal."; // Validate that args is a tuple Expr arg_tuple = call->args[1]; - CHECK(arg_tuple->struct_info_.as()) + TVM_FFI_ICHECK(arg_tuple->struct_info_.as()) << "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. " << "However, the second argument " << arg_tuple << " has struct info " << arg_tuple->struct_info_ << "."; - CHECK(arg_tuple.as() || arg_tuple.as()) + TVM_FFI_ICHECK(arg_tuple.as() || arg_tuple.as()) << "Operation " << call->op << " must hold its arguments as an in-line tuple. " << "However, " << call << " has arguments " << arg_tuple << ", which is neither an in-line tuple, " @@ -902,9 +902,10 @@ TVM_REGISTER_OP("relax.call_py_func") Expr MakeCallPyFunc(StringImm func_name, Tuple args, ffi::Array out_sinfo_list) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); - CHECK(shape != nullptr) << "out_sinfo of call_py_func should have defined ShapeExpr as shape. " - "However, one given structure info is " - << sinfo; + TVM_FFI_ICHECK(shape != nullptr) + << "out_sinfo of call_py_func should have defined ShapeExpr as shape. " + "However, one given structure info is " + << sinfo; } StructInfo out_sinfo{nullptr}; @@ -929,7 +930,7 @@ StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilde // by default return void. return TupleStructInfo(ffi::Array()); } else { - ICHECK_EQ(call->sinfo_args.size(), 1); + TVM_FFI_ICHECK_EQ(call->sinfo_args.size(), 1); return call->sinfo_args[0]; } } @@ -1130,8 +1131,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { StructInfo InferStructInfoSize(const Call& call, const BlockBuilder& ctx) { auto arg_sinfo = GetStructInfo(call->args[0]); auto* tensor_sinfo = GetStructInfo(call->args[0]).as(); - CHECK(tensor_sinfo) << "size expects a tensor input, but received " << arg_sinfo - << "; use MatchCast if necessary"; + TVM_FFI_ICHECK(tensor_sinfo) << "size expects a tensor input, but received " << arg_sinfo + << "; use MatchCast if necessary"; return TensorStructInfo(ShapeExpr(ffi::Array{}), DataType::Int(64)); } @@ -1154,13 +1155,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { // tensor_to_shape StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& ctx) { - ICHECK(call->args.size() == 1); - ICHECK(call->args[0]->struct_info_.defined()); + TVM_FFI_ICHECK(call->args.size() == 1); + TVM_FFI_ICHECK(call->args[0]->struct_info_.defined()); const auto* tsinfo = GetStructInfoAs(call->args[0]); - ICHECK(tsinfo); - ICHECK_EQ(tsinfo->ndim, 1) << "relax.tensor_to_shape expected argument to be 1-d, " - << "but " << call << " has argument " << call->args[0] - << " with struct info " << call->args[0]->struct_info_; + TVM_FFI_ICHECK(tsinfo); + TVM_FFI_ICHECK_EQ(tsinfo->ndim, 1) << "relax.tensor_to_shape expected argument to be 1-d, " + << "but " << call << " has argument " << call->args[0] + << " with struct info " << call->args[0]->struct_info_; if (tsinfo->shape.defined()) { ShapeExpr shape_expr = Downcast(tsinfo->shape.value()); @@ -1190,10 +1191,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { // shape_to_tensor StructInfo ReturnShapeToTensorStructInfo(const Call& call, const BlockBuilder& ctx) { - ICHECK(call->args.size() == 1); - ICHECK(call->args[0]->struct_info_.defined()); + TVM_FFI_ICHECK(call->args.size() == 1); + TVM_FFI_ICHECK(call->args[0]->struct_info_.defined()); const auto* sinfo = GetStructInfoAs(call->args[0]); - ICHECK(sinfo); + TVM_FFI_ICHECK(sinfo); int32_t ndim = sinfo->ndim; return TensorStructInfo(ShapeExpr({PrimExpr(ndim)}), DataType::Int(64)); } @@ -1218,9 +1219,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { // alloc_tensor StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& ctx) { - ICHECK(call->args[0].as()) + TVM_FFI_ICHECK(call->args[0].as()) << "must be ShapeExpr, but got " << call->args[0]->GetTypeKey(); - ICHECK(call->args[1].as()) + TVM_FFI_ICHECK(call->args[1].as()) << "must be DataTypeImm, but got " << call->args[1]->GetTypeKey(); DataType out_dtype; if (const auto* dtype_node = call->args[1].as()) { @@ -1295,7 +1296,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { // memory planning alloc_tensor StructInfo InferStructInfoMemAllocTensor(const Call& call, const BlockBuilder& ctx) { - ICHECK(GetStructInfoAs(call->args[2])) + TVM_FFI_ICHECK(GetStructInfoAs(call->args[2])) << "must be a Expr of ShapeStructInfo, but got " << call->args[1]->GetTypeKey(); DataType out_dtype; if (const auto* dtype_node = call->args[3].as()) { @@ -1542,8 +1543,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { // to_vdevice StructInfo InferToVDeviceStructInfo(const Call& call, const BlockBuilder& ctx) { - ICHECK(call->args.size() == 1); - ICHECK(call->args[0]->struct_info_.defined()); + TVM_FFI_ICHECK(call->args.size() == 1); + TVM_FFI_ICHECK(call->args[0]->struct_info_.defined()); TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); auto attrs = call->attrs.as(); VDevice vdev = attrs->dst_vdevice; @@ -1575,8 +1576,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { // hint_on_device StructInfo InferHintOnDeviceStructInfo(const Call& call, const BlockBuilder& ctx) { - ICHECK(call->args.size() == 1); - ICHECK(call->args[0]->struct_info_.defined()); + TVM_FFI_ICHECK(call->args.size() == 1); + TVM_FFI_ICHECK(call->args[0]->struct_info_.defined()); TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); return data_sinfo; } diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index 5b9ed1e5f529..d084f0c0eb10 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -48,10 +48,10 @@ void CheckNumArguments(const Call& call, const BlockBuilder& ctx) { TensorStructInfo GetInputTensorStructInfo(const Call& call, size_t i_arg, const BlockBuilder& ctx) { Op op = Downcast(call->op); - ICHECK_EQ(op->arguments.size(), call->args.size()) + TVM_FFI_ICHECK_EQ(op->arguments.size(), call->args.size()) << "Failure caught by this check " << "should have previously been caught by `CheckNumArguments`"; - ICHECK_LT(i_arg, op->arguments.size()); + TVM_FFI_ICHECK_LT(i_arg, op->arguments.size()); auto arg = call->args[i_arg]; auto sinfo = GetStructInfo(arg); @@ -148,7 +148,7 @@ ffi::Optional> InferBinaryBroadcastShape( std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int ndim, const ffi::Array& axes) { - ICHECK_NE(ndim, kUnknownNDim) << "The ndim is required to be known for this function."; + TVM_FFI_ICHECK_NE(ndim, kUnknownNDim) << "The ndim is required to be known for this function."; std::vector appeared_dims_set; std::vector axes_non_neg; appeared_dims_set.resize(ndim, /*value=*/false); @@ -180,7 +180,7 @@ std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int nd InferLayoutOutput InferLayoutUnaryEwise( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); return InferLayoutOutput({layout}, {layout}, Attrs(call->attrs)); } diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index cd5406e614c0..19e0398892ab 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -146,7 +146,7 @@ std::tuple GetArgStructInfo(const Call& call, const BlockBuilder& c // Unfortunately, because the `.add_argument()` calls in // TVM_REGISTER_OP occur during initialization of globals and are // not available at compile-time, this cannot be a static_assert. - ICHECK_EQ(n_input, sizeof...(ArgTypes)) + TVM_FFI_ICHECK_EQ(n_input, sizeof...(ArgTypes)) << "Internal error: " << op << " op defines " << n_input << " arguments in its TVM_REGISTER_OP() call, " << "but GetArgStructInfo was given " << sizeof...(ArgTypes) << " template arguments."; @@ -213,10 +213,10 @@ inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx output_sinfo->dtype = f_compute_out_dtype(input_sinfo); if (call->sinfo_args.size() > 0) { auto defined_sinfo = call->sinfo_args[0].as(); - ICHECK(defined_sinfo); + TVM_FFI_ICHECK(defined_sinfo); auto shape = output_sinfo->GetShape(); - ICHECK(shape.defined()); - ICHECK(defined_sinfo->vdevice.has_value()); + TVM_FFI_ICHECK(shape.defined()); + TVM_FFI_ICHECK(defined_sinfo->vdevice.has_value()); return TensorStructInfo(ShapeExpr(shape.value()), output_sinfo->dtype, defined_sinfo->vdevice.value()); } else { @@ -286,10 +286,9 @@ inline std::optional GetElementDType(const StructInfo& sinfo) { return tensor->dtype; } else { return std::nullopt; - LOG(FATAL) << "TypeError: " - << "Only PrimStructInfo and TensorStructInfo " - << "have an associated data type. " - << "Cannot determine element type of " << sinfo; + TVM_FFI_THROW(TypeError) << "Only PrimStructInfo and TensorStructInfo " + << "have an associated data type. " + << "Cannot determine element type of " << sinfo; } } @@ -307,8 +306,7 @@ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& const StructInfo& rhs_sinfo) { auto opt_lhs_dtype = GetElementDType(lhs_sinfo); if (!opt_lhs_dtype) { - ctx->ReportFatal(Diagnostic::Error(call) - << "TypeError: " + ctx->ReportFatal(Diagnostic::Error("TypeError", call) << "Binary operators must have the same datatype for both operands. " << "However, " << call << " has argument " << call->args[0] << " on the LHS, with struct info " << lhs_sinfo << ". This is of type " @@ -318,8 +316,7 @@ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& auto opt_rhs_dtype = GetElementDType(rhs_sinfo); if (!opt_rhs_dtype) { - ctx->ReportFatal(Diagnostic::Error(call) - << "TypeError: " + ctx->ReportFatal(Diagnostic::Error("TypeError", call) << "Binary operators must have the same datatype for both operands. " << "However, " << call << " has argument " << call->args[1] << " on the RHS, with struct info " << rhs_sinfo << ". This is of type " @@ -330,8 +327,7 @@ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& if (lhs_dtype.is_void() || rhs_dtype.is_void()) { return DataType::Void(); } else if (lhs_dtype != rhs_dtype && !lhs_dtype.is_bool() && !rhs_dtype.is_bool()) { - ctx->ReportFatal(Diagnostic::Error(call) - << "TypeError: " + ctx->ReportFatal(Diagnostic::Error("TypeError", call) << "Binary operators must have the same datatype for both operands. " << "However, " << call << " uses datatype " << lhs_dtype << " on the LHS (StructInfo of " << lhs_sinfo << "), and datatype " @@ -381,8 +377,7 @@ inline ffi::Optional InferBinaryArithOpOutVDevice(const Call& call, } if (lhs_vdevice.value() != rhs_vdevice.value()) { - ctx->ReportFatal(Diagnostic::Error(call) - << "TypeErorr: " + ctx->ReportFatal(Diagnostic::Error("TypeError", call) << "Binary operators with Tensor arguments " << "must have the same VDevice for both operands. " << "However, " << call << " has a LHS on VDevice " << lhs_vdevice @@ -471,9 +466,10 @@ inline ffi::Array GetCompletePadding1D(ffi::Array padding) { } else if (padding.size() == 2) { return padding; } - LOG(FATAL) << "The input padding length is expected to be either 1 or 2. However, the given " - "padding is " - << padding; + TVM_FFI_THROW(InternalError) + << "The input padding length is expected to be either 1 or 2. However, the given " + "padding is " + << padding; throw; } @@ -494,9 +490,10 @@ inline ffi::Array GetCompletePadding2D(ffi::Array padding) { } else if (padding.size() == 4) { return padding; } - LOG(FATAL) << "The input padding length is expected to be either 1, 2 or 4. However, the given " - "padding is " - << padding; + TVM_FFI_THROW(InternalError) + << "The input padding length is expected to be either 1, 2 or 4. However, the given " + "padding is " + << padding; throw; } @@ -519,9 +516,10 @@ inline ffi::Array GetCompletePadding3D(ffi::Array padding) { } else if (padding.size() == 6) { return padding; } - LOG(FATAL) << "The input padding length is expected to be either 1, 3 or 6. However, the given " - "padding is " - << padding; + TVM_FFI_THROW(InternalError) + << "The input padding length is expected to be either 1, 3 or 6. However, the given " + "padding is " + << padding; throw; } diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index 7051d2b1b975..81a8688d1c17 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -42,13 +42,13 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, auto lhs_sinfo = GetStructInfo(call->args[0]); auto rhs_sinfo = GetStructInfo(call->args[1]); - CHECK(lhs_sinfo.as() || lhs_sinfo.as()) - << "TypeError: " + TVM_FFI_CHECK(lhs_sinfo.as() || lhs_sinfo.as(), + TypeError) << "Arguments to binary operators must be either R.Tensor or R.Prim types, " << "but expression " << call << " has LHS " << call->args[0] << ", which has StructInfo " << lhs_sinfo; - CHECK(rhs_sinfo.as() || rhs_sinfo.as()) - << "TypeError: " + TVM_FFI_CHECK(rhs_sinfo.as() || rhs_sinfo.as(), + TypeError) << "Arguments to binary operators must be either R.Tensor or R.Prim types, " << "but expression " << call << " has RHS " << call->args[1] << ", which has StructInfo " << rhs_sinfo; @@ -104,7 +104,7 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, ffi::Optional> output_shape = InferBinaryBroadcastShape(call, ctx, lhs_shape.value(), rhs_shape.value()); if (output_shape.defined()) { - ICHECK_EQ(static_cast(output_shape.value().size()), output_ndim); + TVM_FFI_ICHECK_EQ(static_cast(output_shape.value().size()), output_ndim); return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdevice); } } @@ -145,14 +145,14 @@ StructInfo InferStructInfoBroadcastCMP(const Call& call, const BlockBuilder& ctx InferLayoutOutput InferLayoutBinaryEwise( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); LayoutDecision layout1 = GetLayoutDecision(var_layout_map, call->args[0]); LayoutDecision layout2 = GetLayoutDecision(var_layout_map, call->args[1]); auto* x1_sinfo = GetStructInfoAs(call->args[0]); auto* x2_sinfo = GetStructInfoAs(call->args[1]); - ICHECK(!x1_sinfo->IsUnknownNdim() && !x2_sinfo->IsUnknownNdim()) + TVM_FFI_ICHECK(!x1_sinfo->IsUnknownNdim() && !x2_sinfo->IsUnknownNdim()) << "Unknown dim tensors should not be handled by this function"; ffi::Optional shape1 = ffi::GetRef(x1_sinfo->shape.as()); diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index a9a0872d683a..88ea67ab99e1 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -51,7 +51,7 @@ Expr full(ffi::Variant> shape, Expr fill_value, } else if (const auto* _array = shape.as()) { shape_in_expr = ShapeExpr(ffi::GetRef>(_array)); } else { - LOG(FATAL) + TVM_FFI_THROW(InternalError) << "Full only expects the input shape to be either an Expr or an Array of PrimExpr. "; } @@ -173,7 +173,7 @@ StructInfo InferStructInfoOnesLikeZerosLike(const Call& call, const BlockBuilder /* relax.ones & relax.ones_like */ Expr ones(Expr shape, DataType dtype) { - CHECK(!dtype.is_void()) << "Ones op expects the input dtype not to be void"; + TVM_FFI_ICHECK(!dtype.is_void()) << "Ones op expects the input dtype not to be void"; ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; @@ -210,7 +210,7 @@ TVM_REGISTER_OP("relax.ones_like") /* relax.zeros & relax.zeros_like */ Expr zeros(Expr shape, DataType dtype) { - CHECK(!dtype.is_void()) << "Zeros op expects the input dtype not to be void"; + TVM_FFI_ICHECK(!dtype.is_void()) << "Zeros op expects the input dtype not to be void"; ObjectPtr attrs = ffi::make_object(); attrs->dtype = dtype; diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index e59ba7b597c8..0dfb531f2b7f 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -148,7 +148,7 @@ Expr strided_slice(Expr x, Expr axes, Expr begin, Expr end, ffi::Optional size_t length = tuple->fields.size(); if (known_length.has_value()) { const auto& prev = known_length.value(); - CHECK_EQ(length, std::get(prev)) + TVM_FFI_ICHECK_EQ(length, std::get(prev)) << "The strided_slice operator requires that " << "the axes, begin, end, and strides tuples are all the same length. " << "However, the " << std::get(prev) << " argument (" @@ -215,9 +215,9 @@ ffi::Optional> UnpackTupleOfPrimValue(ffi::Optional()) return std::nullopt; auto tuple = sinfo.as(); - CHECK(tuple) << "TypeError: " - << "The struct info " << sinfo << " cannot contain a tuple whose elements are " - << PrimType::ContainerType::_type_key; + TVM_FFI_CHECK(tuple, TypeError) << "The struct info " << sinfo + << " cannot contain a tuple whose elements are " + << PrimType::ContainerType::_type_key; ffi::Array output; for (size_t i = 0; i < tuple->fields.size(); i++) { @@ -226,11 +226,10 @@ ffi::Optional> UnpackTupleOfPrimValue(ffi::Optional()) return std::nullopt; auto prim_sinfo = field.as(); - CHECK(prim_sinfo) << "TypeError: " - << "The struct info " << sinfo - << " cannot contain a tuple whose elements are " - << PrimType::ContainerType::_type_key << ", because element " << i - << " has struct info " << field; + TVM_FFI_CHECK(prim_sinfo, TypeError) + << "The struct info " << sinfo << " cannot contain a tuple whose elements are " + << PrimType::ContainerType::_type_key << ", because element " << i << " has struct info " + << field; if (!prim_sinfo->value.defined()) return std::nullopt; @@ -275,7 +274,7 @@ ffi::Optional> UnpackTupleOfPrimValue(ffi::Optional e StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx) { size_t n_args = call->args.size(); - CHECK(4 <= n_args && n_args <= 5) + TVM_FFI_ICHECK(4 <= n_args && n_args <= 5) << "Operator " << call->op << " accepts either three arguments (data, axes, begin, end) " << " or four arguments (data, axes, begin, end, strides), " << "but received " << n_args << " in expression " << call; @@ -303,7 +302,8 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx } }(); - CHECK(IsBaseOf(relax::TensorStructInfo(DataType::Void(), kUnknownNDim), GetStructInfo(data))) + TVM_FFI_ICHECK( + IsBaseOf(relax::TensorStructInfo(DataType::Void(), kUnknownNDim), GetStructInfo(data))) << "Operator " << call->op << " requires the first argument to be a tensor. " << "However, in expression " << call << ", the first argument " << data << " has struct info " << GetStructInfo(data); @@ -326,10 +326,11 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx auto check_tuple = [&](const char* name, Expr expr) { auto sinfo = GetStructInfo(expr); - CHECK(is_base_of_tuple_of_int64(sinfo)) << "Operator " << call->op << " requires the " << name - << " argument to be a tuple of int64 PrimValues. " - << "However, in expression " << call << ", the " << name - << " argument " << expr << " has struct info " << sinfo; + TVM_FFI_ICHECK(is_base_of_tuple_of_int64(sinfo)) + << "Operator " << call->op << " requires the " << name + << " argument to be a tuple of int64 PrimValues. " + << "However, in expression " << call << ", the " << name << " argument " << expr + << " has struct info " << sinfo; }; check_tuple("axes", call->args[1]); check_tuple("begin", call->args[2]); @@ -361,7 +362,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx if (!opt_begin_tuple) return std::nullopt; auto begin_tuple = opt_begin_tuple.value(); - CHECK_EQ(axes_tuple.size(), begin_tuple.size()) + TVM_FFI_ICHECK_EQ(axes_tuple.size(), begin_tuple.size()) << "For operator " << call->op << ", " << "the number of axes provided must match the number of 'begin' indices. " << "However, there are " << axes_tuple.size() << " axes specified (" << axes_tuple @@ -371,7 +372,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx if (!opt_end_tuple) return std::nullopt; auto end_tuple = opt_end_tuple.value(); - CHECK_EQ(axes_tuple.size(), end_tuple.size()) + TVM_FFI_ICHECK_EQ(axes_tuple.size(), end_tuple.size()) << "For operator " << call->op << ", " << "the number of axes provided must match the number of 'end' indices. " << "However, there are " << axes_tuple.size() << " axes specified (" << axes_tuple @@ -387,7 +388,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx strides_tuple = ffi::Array(axes_tuple.size(), IntImm(DataType::Int(64), 1)); } - CHECK_EQ(axes_tuple.size(), strides_tuple.size()) + TVM_FFI_ICHECK_EQ(axes_tuple.size(), strides_tuple.size()) << "For operator " << call->op << ", " << "when the optional 'strides' argument is provided, " << "the number of axes provided must match the number of strides provided. " @@ -438,16 +439,17 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx InferLayoutOutput InferLayoutStridedSlice( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); - ICHECK(attrs != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - CHECK(tensor_sinfo) << "Invalid Call"; - CHECK(!tensor_sinfo->IsUnknownNdim()) << "Layout inference only supports known dimensionality, " - << "but expression " << call << " has argument " - << call->args[0] << " of unknown dimensionality."; + TVM_FFI_ICHECK(tensor_sinfo) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) + << "Layout inference only supports known dimensionality, " + << "but expression " << call << " has argument " << call->args[0] + << " of unknown dimensionality."; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); // Can't handle sub indexed layouts. if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) { @@ -455,10 +457,10 @@ InferLayoutOutput InferLayoutStridedSlice( } auto opt_axes_tuple = UnpackTupleOfPrimValue(GetStructInfo(call->args[1])); - CHECK(opt_axes_tuple) << "Layout inference of " << call->op - << " requires slices to be along static axes. " - << "However, expression " << call << " slices along non-static axes " - << call->args[1]; + TVM_FFI_ICHECK(opt_axes_tuple) << "Layout inference of " << call->op + << " requires slices to be along static axes. " + << "However, expression " << call + << " slices along non-static axes " << call->args[1]; ffi::Array axes_tuple = opt_axes_tuple.value(); ffi::Array new_axes; @@ -500,7 +502,7 @@ StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& const auto* end_sinfo = GetStructInfoAs(call->args[2]); const auto* strides_sinfo = GetStructInfoAs(call->args[3]); - ICHECK(data_sinfo); + TVM_FFI_ICHECK(data_sinfo); if (data_sinfo->IsUnknownNdim()) { LOG(WARNING) << "When data rank is unknown, dynamic strided slice assumes begin/end/strides " "tensors are well-formed. It could produce runtime error when this assumption " @@ -515,25 +517,26 @@ StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& int n_axis = data_sinfo->ndim; auto diag_def = [&](const TensorStructInfoNode* sinfo, ffi::String name) { - ICHECK(sinfo) << "Dynamic strided slice requires the input " << name - << " to be have the struct info. Please try normalizing the inputs."; - CHECK_EQ(sinfo->ndim, 1) << "Dynamic strided slice requires " << name - << " to be 1d tensor (list of values)."; + TVM_FFI_ICHECK(sinfo) << "Dynamic strided slice requires the input " << name + << " to be have the struct info. Please try normalizing the inputs."; + TVM_FFI_ICHECK_EQ(sinfo->ndim, 1) + << "Dynamic strided slice requires " << name << " to be 1d tensor (list of values)."; const auto* shape = sinfo->shape.as(); - ICHECK(shape) << "Dynamic strided slice requires the input " << name - << " to have well-defined shape."; + TVM_FFI_ICHECK(shape) << "Dynamic strided slice requires the input " << name + << " to have well-defined shape."; // NOTE(tvm-team): This strong restriction seems necessary for now until we have a generic // solution in converting 1d Tensor with unknown num_elem to ffi::Array. const auto* num_elem = shape->values[0].as(); - ICHECK(num_elem) << "Dynamic strided slice requires the input " << name - << " to have a known integer shape value."; - CHECK_EQ(num_elem->value, n_axis) << "Dynamic strided slice requires the number of indices in " - << name << " to equal the number of axes."; + TVM_FFI_ICHECK(num_elem) << "Dynamic strided slice requires the input " << name + << " to have a known integer shape value."; + TVM_FFI_ICHECK_EQ(num_elem->value, n_axis) + << "Dynamic strided slice requires the number of indices in " << name + << " to equal the number of axes."; if (sinfo->IsUnknownDtype()) { LOG(WARNING) << "Dynamic strided slice assumes " << name << " to be int64 when it is not specified."; } else { - CHECK(sinfo->dtype == DataType::Int(64)) + TVM_FFI_ICHECK(sinfo->dtype == DataType::Int(64)) << "Dynamic strided_slice expects the input " << name << "values to be all int64. However, " << name << " has dtype " << sinfo->dtype << "."; } @@ -551,13 +554,14 @@ StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& InferLayoutOutput InferLayoutDynStridedSlice( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - CHECK(tensor_sinfo) << "Invalid Call"; - CHECK(!tensor_sinfo->IsUnknownNdim()) << "Layout inference only supports known dimensionality, " - << "but expression " << call << " has argument " - << call->args[0] << " of unknown dimensionality."; + TVM_FFI_ICHECK(tensor_sinfo) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) + << "Layout inference only supports known dimensionality, " + << "but expression " << call << " has argument " << call->args[0] + << " of unknown dimensionality."; int ndim = tensor_sinfo->ndim; // Since begin/end/strides are dynamic tensors, we cannot transform // them at compile time. Fall back to the initial layout. diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc index 01843ba0a3c0..2e05ea4a81bc 100644 --- a/src/relax/op/tensor/inspect.cc +++ b/src/relax/op/tensor/inspect.cc @@ -36,50 +36,47 @@ namespace relax { namespace inspect { TensorStructInfo GetTensorArgInfo(const Call& call) { - CHECK_EQ(call->args.size(), 1) << "TypeError: " - << "Operator " << call->op << " expects one argument, " - << "but received " << call->args.size() - << " arguments: " << call->args; + TVM_FFI_CHECK_EQ(call->args.size(), 1, TypeError) + << "Operator " << call->op << " expects one argument, " + << "but received " << call->args.size() << " arguments: " << call->args; const auto& arg = call->args[0]; auto sinfo = GetStructInfo(arg); auto tensor_sinfo = sinfo.as(); - CHECK(tensor_sinfo) << "TypeError: " - << "Operator " << call->op << " expects a tensor argument, " - << "but argument " << arg << " has struct info " << sinfo; + TVM_FFI_CHECK(tensor_sinfo, TypeError) + << "Operator " << call->op << " expects a tensor argument, " + << "but argument " << arg << " has struct info " << sinfo; return tensor_sinfo.value(); } std::tuple GetTensorArgInfoWithIndex(const Call& call) { - CHECK_EQ(call->args.size(), 2) << "TypeError: " - << "Operator " << call->op << " expects two arguments, " - << "but received " << call->args.size() - << " arguments: " << call->args; + TVM_FFI_CHECK_EQ(call->args.size(), 2, TypeError) + << "Operator " << call->op << " expects two arguments, " + << "but received " << call->args.size() << " arguments: " << call->args; const auto& arg = call->args[0]; const auto& axis = call->args[1]; auto tensor_sinfo = arg->struct_info_.as(); - CHECK(tensor_sinfo) << "TypeError: " - << "Operator " << call->op << " expects arguments (tensor, axis), " - << "but the first argument " << arg << " in expression " << call - << " has struct info " << arg->struct_info_; + TVM_FFI_CHECK(tensor_sinfo, TypeError) + << "Operator " << call->op << " expects arguments (tensor, axis), " + << "but the first argument " << arg << " in expression " << call << " has struct info " + << arg->struct_info_; auto axis_sinfo = axis->struct_info_.as(); - CHECK(axis_sinfo) << "TypeError: " - << "Operator " << call->op << " expects arguments (tensor, axis), " - << "but the second argument " << arg << " in expression " << call - << " has struct info " << axis->struct_info_; + TVM_FFI_CHECK(axis_sinfo, TypeError) + << "Operator " << call->op << " expects arguments (tensor, axis), " + << "but the second argument " << arg << " in expression " << call << " has struct info " + << axis->struct_info_; auto int_imm_axis = axis_sinfo->value.as(); if (int_imm_axis) { - CHECK_GE(int_imm_axis->value, 0); + TVM_FFI_ICHECK_GE(int_imm_axis->value, 0); } if (int_imm_axis && !tensor_sinfo->IsUnknownNdim()) { - CHECK_LT(int_imm_axis->value, tensor_sinfo->ndim) - << "ValueError: " + TVM_FFI_CHECK_LT(int_imm_axis->value, tensor_sinfo->ndim, ValueError) << "Expression " << call << " attempts to access " << arg << ".shape[" << int_imm_axis->value << "]" << ", but " << arg << ".shape only has " << tensor_sinfo->ndim << " elements"; diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc index 06b7856dd239..fa5c7e339862 100644 --- a/src/relax/op/tensor/linear_algebra.cc +++ b/src/relax/op/tensor/linear_algebra.cc @@ -153,7 +153,7 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { if (!x2_appended) { output_shape.push_back(x2_shape->values[x2_ndim - 1]); } - ICHECK_EQ(static_cast(output_shape.size()), output_ndim); + TVM_FFI_ICHECK_EQ(static_cast(output_shape.size()), output_ndim); if (vdev.defined()) { return TensorStructInfo(ShapeExpr(output_shape), out_dtype, vdev); } diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index d4bbbd208ba7..e5f3d19e8dd9 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -330,14 +330,14 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutConcat( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); - ICHECK(attrs != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; NLayout nlayout = GetNLayout(var_layout_map, call->args[0]); - ICHECK(nlayout.IsNested()); - ICHECK(nlayout.NestedArray()[0].IsLeaf()); + TVM_FFI_ICHECK(nlayout.IsNested()); + TVM_FFI_ICHECK(nlayout.NestedArray()[0].IsLeaf()); int n_tensor = nlayout.NestedArray().size(); LayoutDecision layout = nlayout.NestedArray()[0].LeafValue(); @@ -347,17 +347,17 @@ InferLayoutOutput InferLayoutConcat( // On any failre select first occuring regular layout for all auto nlayout_array = nlayout.NestedArray(); for (auto n_layout : nlayout_array) { - ICHECK(n_layout.IsLeaf()); + TVM_FFI_ICHECK(n_layout.IsLeaf()); LayoutDecision in_layout = n_layout.LeafValue(); if (in_layout->layout.ndim() != in_layout->layout.ndim_primal()) { const auto* tuple_sinfo = GetStructInfoAs(call->args[0]); - ICHECK(tuple_sinfo != nullptr) + TVM_FFI_ICHECK(tuple_sinfo != nullptr) << " expects the input to be a Tuple of Tensors. However, the given input is " << call->args[0]->struct_info_->GetTypeKey(); for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { StructInfo field_sinfo = tuple_sinfo->fields[i]; const auto* field_tensor_sinfo = field_sinfo.as(); - ICHECK(field_tensor_sinfo != nullptr) + TVM_FFI_ICHECK(field_tensor_sinfo != nullptr) << call->op << " expects the input to be a Tuple of Tensors. However, the given input is " << call->args[0]->struct_info_; @@ -448,23 +448,23 @@ StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx) if (output_shape[i].defined()) { continue; } - ICHECK_LT(i_data_shape, data_sinfo->ndim); + TVM_FFI_ICHECK_LT(i_data_shape, data_sinfo->ndim); output_shape[i] = data_shape->values[i_data_shape]; ++i_data_shape; } - ICHECK_EQ(i_data_shape, data_sinfo->ndim); + TVM_FFI_ICHECK_EQ(i_data_shape, data_sinfo->ndim); return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype, data_sinfo->vdevice); } InferLayoutOutput InferLayoutExpandDims( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); - ICHECK(attrs != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); int ndim = tensor_sinfo->ndim; @@ -847,13 +847,13 @@ StructInfo InferStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx) InferLayoutOutput InferLayoutPermuteDims( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); - ICHECK(attrs != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; int ndim = tensor_sinfo->ndim; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); @@ -910,32 +910,35 @@ Expr ConvertNewShapeToExpr(const Expr& data, } else { array = shape.as(); } - CHECK(array != nullptr) << "Reshape only expects the input new shape to be either an Expr or an " - "Array of PrimExprs. However, the given new shape is " - << shape; + TVM_FFI_ICHECK(array != nullptr) + << "Reshape only expects the input new shape to be either an Expr or an " + "Array of PrimExprs. However, the given new shape is " + << shape; int dim_to_infer = -1; // Keep track of which dimensions should be copied from input. std::vector zero_dims; for (int i = 0; i < static_cast(array->size()); ++i) { const auto* _len = array->at(i).as(); - CHECK(_len != nullptr) << "Reshape only expects the input new shape to be either an Expr or an " - "Array of PrimExprs. However, the given new shape is " - << shape; + TVM_FFI_ICHECK(_len != nullptr) + << "Reshape only expects the input new shape to be either an Expr or an " + "Array of PrimExprs. However, the given new shape is " + << shape; PrimExpr len = ffi::GetRef(_len); - CHECK(len->dtype.is_int()) << "Reshape requires the new shape values to be all " - "integers. However, the give new shape is " - << shape; + TVM_FFI_ICHECK(len->dtype.is_int()) << "Reshape requires the new shape values to be all " + "integers. However, the give new shape is " + << shape; const auto* int_len = len.as(); if (int_len != nullptr && int_len->value == 0) { // Note that this dimension should be copied from the original shape. zero_dims.push_back(i); } else if (int_len != nullptr && int_len->value == -1) { - CHECK_EQ(dim_to_infer, -1) << "Reshape accepts at most one \"-1\" in the new shape. However, " - "there are multiple \"-1\" in the given new shape " - << shape; + TVM_FFI_ICHECK_EQ(dim_to_infer, -1) + << "Reshape accepts at most one \"-1\" in the new shape. However, " + "there are multiple \"-1\" in the given new shape " + << shape; dim_to_infer = i; } else { - CHECK(int_len == nullptr || int_len->value > 0) + TVM_FFI_ICHECK(int_len == nullptr || int_len->value > 0) << "Reshape requires all values in the new shape to be positive except a single \"-1\". " "However, the given new shape is " << shape; @@ -950,14 +953,14 @@ Expr ConvertNewShapeToExpr(const Expr& data, // Otherwise, we require the input tensor to have known shape value for inference. const auto* data_sinfo = GetStructInfoAs(data); - CHECK(data_sinfo != nullptr) + TVM_FFI_ICHECK(data_sinfo != nullptr) << "Reshape expects the input data to be a Tensor. However, the given input is " << data->struct_info_->GetTypeKey(); - CHECK(data_sinfo->shape.defined()) + TVM_FFI_ICHECK(data_sinfo->shape.defined()) << "Reshape expects the input tensor to have known shape when there is some dimension length " "to infer. However, the given input has no shape."; const auto* shape_sinfo = GetStructInfoAs(data_sinfo->shape.value()); - CHECK(shape_sinfo != nullptr && shape_sinfo->values.defined()) + TVM_FFI_ICHECK(shape_sinfo != nullptr && shape_sinfo->values.defined()) << "Reshape expects the input tensor to have known shape when there is some dimension length " "to infer. However, the given input shape is " << data_sinfo->shape << " whose shape value is unknown."; @@ -1023,7 +1026,7 @@ StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { ffi::Optional> old_shape_values; if (data_sinfo->shape.defined()) { const auto* old_shape_sinfo = GetStructInfoAs(data_sinfo->shape.value()); - ICHECK_NOTNULL(old_shape_sinfo); + TVM_FFI_ICHECK_NOTNULL(old_shape_sinfo); old_shape_values = old_shape_sinfo->values; } @@ -1065,19 +1068,22 @@ Expr split(Expr x, ffi::Variant> indices_or_sections, if (const auto* indices = indices_or_sections.as()) { for (int i = 0; i < static_cast(indices->size()); ++i) { const auto* idx = indices->at(i).as(); - CHECK(idx != nullptr) << "Split op only accepts an array of integers as the indices. " - "However, the given indices " - << indices_or_sections << " contains some non-integer."; + TVM_FFI_ICHECK(idx != nullptr) + << "Split op only accepts an array of integers as the indices. " + "However, the given indices " + << indices_or_sections << " contains some non-integer."; } indices_or_sections_obj = ConvertIntImmToInt64(ffi::GetRef>(indices)); } else if (const auto* n_section = indices_or_sections.as()) { - CHECK_GT(n_section->value, 0) << "Split op expects the input number of sections to be a " - "positive integer. However, the given number of sections is " - << n_section->value; + TVM_FFI_ICHECK_GT(n_section->value, 0) + << "Split op expects the input number of sections to be a " + "positive integer. However, the given number of sections is " + << n_section->value; indices_or_sections_obj = IntImm(DataType::Int(64), n_section->value); } else { - LOG(FATAL) << "Split op expects the input indices_or_sections to be either an Array of " - "PrimExpr or an integer."; + TVM_FFI_THROW(InternalError) + << "Split op expects the input indices_or_sections to be either an Array of " + "PrimExpr or an integer."; } attrs->indices_or_sections = indices_or_sections_obj; attrs->axis = axis; @@ -1111,7 +1117,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice))); } - ICHECK_NE(axis, -1); + TVM_FFI_ICHECK_NE(axis, -1); IntImm zero(DataType::Int(64), /*value=*/0); @@ -1145,7 +1151,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { } return TupleStructInfo(output_sinfo); } else if (const auto* p_n_section = attrs->indices_or_sections.as()) { - ICHECK_GT(p_n_section->value, 0); + TVM_FFI_ICHECK_GT(p_n_section->value, 0); int n_section = p_n_section->value; // When the number of section is one, return the input tensor's struct info. if (n_section == 1) { @@ -1156,7 +1162,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { return TupleStructInfo(ffi::Array( n_section, TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice))); } - ICHECK_NE(axis, -1); + TVM_FFI_ICHECK_NE(axis, -1); PrimExpr split_len = ceildiv(data_shape->values[axis], n_section); split_len = ctx->GetAnalyzer()->Simplify(split_len); @@ -1174,20 +1180,20 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { TensorStructInfo(ShapeExpr(shape), data_sinfo->dtype, data_sinfo->vdevice)); return TupleStructInfo(output_sinfo); } - ICHECK(false) << "Cannot reach here."; + TVM_FFI_ICHECK(false) << "Cannot reach here."; throw; } InferLayoutOutput InferLayoutSplit( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); - ICHECK(attrs != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); StructInfo out_sinfo = InferStructInfoSplit(call, BlockBuilder::Create(IRModule())); @@ -1199,14 +1205,14 @@ InferLayoutOutput InferLayoutSplit( */ if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) { for (const auto& si : out_tuple->fields) { - ICHECK(si->IsInstance()) + TVM_FFI_ICHECK(si->IsInstance()) << "Fields of TupleStructInfo must be TensorStructInfo" "output structinfo, but got " << si; auto sinfo = Downcast(si); ffi::Optional shape_expr = ffi::GetRef(sinfo->shape.as()); - CHECK(shape_expr.defined()); + TVM_FFI_ICHECK(shape_expr.defined()); auto shape_arr = shape_expr.value(); if (!CanProveLayoutTransform(InitialLayout(tensor_sinfo->ndim), existing_layout->layout, shape_arr->values)) { @@ -1218,7 +1224,7 @@ InferLayoutOutput InferLayoutSplit( ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = FindAxis(existing_layout->layout, attrs->axis); - ICHECK(out_tuple != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(out_tuple != nullptr) << "Invalid Call"; NLayout tuple_layouts(ffi::Array(out_tuple->fields.size(), existing_layout)); return InferLayoutOutput({existing_layout}, {tuple_layouts}, Attrs(new_attrs)); } @@ -1326,17 +1332,17 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutSqueeze( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); - ICHECK(attrs != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; - ICHECK(tensor_sinfo->shape.defined()) << "Only support static shape for now"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + TVM_FFI_ICHECK(tensor_sinfo->shape.defined()) << "Only support static shape for now"; int ndim = tensor_sinfo->ndim; const auto* shape = tensor_sinfo->shape.as(); - ICHECK(shape != nullptr) << "Only support static shape for now"; + TVM_FFI_ICHECK(shape != nullptr) << "Only support static shape for now"; ffi::Array axis; if (attrs->axis.defined()) { @@ -1496,7 +1502,7 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { } const auto* attrs = call->attrs.as(); - ICHECK(attrs != nullptr) << "Stack must have StackAttrs"; + TVM_FFI_ICHECK(attrs != nullptr) << "Stack must have StackAttrs"; // Default axis is 0 if not specified int output_ndim = tensor_sinfo[0]->ndim + 1; // Stack adds one dimension @@ -1608,13 +1614,13 @@ StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutStack( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); - ICHECK(attrs != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; NLayout nlayout = GetNLayout(var_layout_map, call->args[0]); - ICHECK(nlayout.IsNested()); - ICHECK(nlayout.NestedArray()[0].IsLeaf()); + TVM_FFI_ICHECK(nlayout.IsNested()); + TVM_FFI_ICHECK(nlayout.NestedArray()[0].IsLeaf()); int n_tensor = nlayout.NestedArray().size(); LayoutDecision layout = nlayout.NestedArray()[0].LeafValue(); @@ -1808,13 +1814,13 @@ StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutRepeat( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); - ICHECK(attrs != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); int ndim = tensor_sinfo->ndim; @@ -1854,7 +1860,7 @@ InferLayoutOutput InferLayoutRepeat( break; } } - ICHECK_GE(new_axis, 0) << "Failed to find transformed axis"; + TVM_FFI_ICHECK_GE(new_axis, 0) << "Failed to find transformed axis"; ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = new_axis; @@ -1932,13 +1938,13 @@ StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutTile( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); - ICHECK(attrs != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); int ndim = tensor_sinfo->ndim; @@ -1970,7 +1976,7 @@ InferLayoutOutput InferLayoutTile( for (int i = 0; i < ndim; ++i) { const tir::LayoutAxis& axis = existing_layout_obj[i]; int pos_in_initial = initial_layout.IndexOf(axis); - ICHECK_NE(pos_in_initial, -1) << "Axis not found in initial layout"; + TVM_FFI_ICHECK_NE(pos_in_initial, -1) << "Axis not found in initial layout"; // If len(repeats) < ndim, repeats are right-aligned. // pos_in_initial >= (ndim - l) means it's within the repeats array range. if (pos_in_initial >= ndim - l) { @@ -1982,7 +1988,7 @@ InferLayoutOutput InferLayoutTile( } else { // Different dimension: handle dimension expansion. // This case only happens when l > ndim. - ICHECK_GT(l, ndim); + TVM_FFI_ICHECK_GT(l, ndim); int num_new_dims = l - ndim; // Repeats for new dimensions are not affected by layout change. for (int i = 0; i < num_new_dims; ++i) { @@ -1992,7 +1998,7 @@ InferLayoutOutput InferLayoutTile( for (int i = 0; i < ndim; ++i) { const tir::LayoutAxis& axis = existing_layout_obj[i]; int pos_in_initial = initial_layout.IndexOf(axis); - ICHECK_NE(pos_in_initial, -1) << "Axis not found in initial layout"; + TVM_FFI_ICHECK_NE(pos_in_initial, -1) << "Axis not found in initial layout"; new_repeats.push_back(attrs->repeats[pos_in_initial + num_new_dims]); } } @@ -2050,13 +2056,13 @@ StructInfo InferStructInfoFlip(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutFlip( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); - ICHECK(attrs != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]); int ndim = tensor_sinfo->ndim; @@ -2071,7 +2077,7 @@ InferLayoutOutput InferLayoutFlip( } const int new_axis = FindAxis(existing_layout->layout, axis); - ICHECK_GE(new_axis, 0) << "Failed to find transformed axis"; + TVM_FFI_ICHECK_GE(new_axis, 0) << "Failed to find transformed axis"; ObjectPtr new_attrs = ffi::make_object(*attrs); new_attrs->axis = Integer(new_axis); @@ -2153,9 +2159,9 @@ StructInfo InferStructInfoGatherElements(const Call& call, const BlockBuilder& c InferLayoutOutput InferLayoutGatherElements( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]); LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, call->args[1]); @@ -2172,8 +2178,8 @@ InferLayoutOutput InferLayoutGatherElements( if (layout->layout.ndim() != layout->layout.ndim_primal()) { const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; int ndim = tensor_sinfo->ndim; layout = LayoutDecision(InitialLayout(ndim)); } @@ -2223,7 +2229,7 @@ StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) { << "GatherND requires the input indices to be a Tensor. However, the given one is " << call->args[1]->struct_info_->GetTypeKey()); } - ICHECK_GE(attrs->batch_dims.IntValue(), 0); + TVM_FFI_ICHECK_GE(attrs->batch_dims.IntValue(), 0); int batch_dims = attrs->batch_dims.IntValue(); int input_dims = data_sinfo->ndim; if (!indices_sinfo->IsUnknownDtype() && indices_sinfo->dtype != DataType::Int(64)) { @@ -2276,7 +2282,7 @@ StructInfo InferStructInfoGatherND(const Call& call, const BlockBuilder& ctx) { for (int i = batch_dims + l; i < input_dims; ++i) { out_shape.push_back(data_shape->values[i]); } - ICHECK_EQ(out_shape.size(), output_ndim); + TVM_FFI_ICHECK_EQ(out_shape.size(), output_ndim); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } @@ -2650,9 +2656,9 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder& InferLayoutOutput InferLayoutScatterElements( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); - ICHECK(attrs) << "Invalid Call"; + TVM_FFI_ICHECK(attrs) << "Invalid Call"; LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]); LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, call->args[1]); @@ -2665,8 +2671,8 @@ InferLayoutOutput InferLayoutScatterElements( if (layout->layout.ndim() != layout->layout.ndim_primal()) { const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; int ndim = tensor_sinfo->ndim; layout = LayoutDecision(InitialLayout(ndim)); } @@ -2703,7 +2709,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) { // `call->args` contains: [data, indices, updates] arith::Analyzer* analyzer = ctx->GetAnalyzer(); - ICHECK_EQ(call->args.size(), 3); + TVM_FFI_ICHECK_EQ(call->args.size(), 3); const auto* data_sinfo = GetStructInfoAs(call->args[0]); const auto* indices_sinfo = GetStructInfoAs(call->args[1]); const auto* updates_sinfo = GetStructInfoAs(call->args[2]); @@ -2817,7 +2823,7 @@ StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutScatterND( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]); LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, call->args[1]); @@ -2825,10 +2831,10 @@ InferLayoutOutput InferLayoutScatterND( const auto* data_sinfo = GetStructInfoAs(call->args[0]); const auto* updates_sinfo = GetStructInfoAs(call->args[2]); - ICHECK(data_sinfo != nullptr) << "Invalid Call"; - ICHECK(updates_sinfo != nullptr) << "Invalid Call"; - ICHECK(!data_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; - ICHECK(!updates_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + TVM_FFI_ICHECK(data_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(updates_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!data_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; + TVM_FFI_ICHECK(!updates_sinfo->IsUnknownNdim()) << "Only support static ndim for now"; LayoutDecision layout = data_layout; LayoutDecision out_updates_layout = updates_layout; @@ -2972,9 +2978,9 @@ StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx const auto* src_shape_node = src_sinfo->shape.as(); if (data_shape_node && src_shape_node && !src_sinfo->IsUnknownNdim()) { - ICHECK_EQ(data_shape_node->values.size(), static_cast(ndim)) + TVM_FFI_ICHECK_EQ(data_shape_node->values.size(), static_cast(ndim)) << "Internal error: data_shape_node rank mismatch with data_sinfo->ndim for call " << call; - ICHECK_EQ(src_shape_node->values.size(), static_cast(src_sinfo->ndim)) + TVM_FFI_ICHECK_EQ(src_shape_node->values.size(), static_cast(src_sinfo->ndim)) << "Internal error: src_shape_node rank mismatch with src_sinfo->ndim for call " << call; PrimExpr num_elem = tvm::floordiv((stop_val - start_val + step_val - PrimExpr(1)), step_val); @@ -3030,10 +3036,11 @@ Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, i // Check if on_value and off_value have the same dtype DataType on_dtype = on_value->value->dtype; DataType off_dtype = off_value->value->dtype; - ICHECK(on_dtype == off_dtype) << "one_hot: on_value and off_value must have the same dtype, " - << "but got " << on_dtype << " and " << off_dtype; + TVM_FFI_ICHECK(on_dtype == off_dtype) + << "one_hot: on_value and off_value must have the same dtype, " + << "but got " << on_dtype << " and " << off_dtype; - ICHECK(depth > 0) << "one_hot: depth must be positive, but got " << depth; + TVM_FFI_ICHECK(depth > 0) << "one_hot: depth must be positive, but got " << depth; static const Op& op = Op::Get("relax.one_hot"); return Call(op, {indices, on_value, off_value}, Attrs(attrs), {}); @@ -3050,7 +3057,7 @@ StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { PrimValue on_value = Downcast(call->args[1]); PrimValue off_value = Downcast(call->args[2]); // Check if on_value and off_value have the same dtype - ICHECK(on_value->value->dtype == off_value->value->dtype) + TVM_FFI_ICHECK(on_value->value->dtype == off_value->value->dtype) << "one_hot: on_value and off_value must have the same dtype, " << "but got " << on_value->value->dtype << " and " << off_value->value->dtype; DataType dtype = on_value->value->dtype; @@ -3079,7 +3086,7 @@ StructInfo InferStructInfoOneHot(const Call& call, const BlockBuilder& ctx) { if (axis < 0) { axis += output_shape.size() + 1; } - ICHECK(0 <= axis && axis <= static_cast(output_shape.size())) + TVM_FFI_ICHECK(0 <= axis && axis <= static_cast(output_shape.size())) << "one_hot: axis must be in the range of [0, " << output_shape.size() << "], " << "but got " << axis; output_shape.insert(output_shape.begin() + axis, attrs->depth); diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index 0cd221d53d1c..0048dd7a347b 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -156,7 +156,7 @@ StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) { } return TensorStructInfo(output_dtype, output_ndim); } - ICHECK_EQ(static_cast(broadcasted_shape.value().size()), output_ndim); + TVM_FFI_ICHECK_EQ(static_cast(broadcasted_shape.value().size()), output_ndim); if (vdev.defined()) { return TensorStructInfo(ShapeExpr(broadcasted_shape.value()), output_dtype, vdev); } @@ -206,7 +206,7 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx out_ndim = kUnknownNDim; } else { out_ndim = data_sinfo->ndim - 1; - ICHECK_GE(out_ndim, 0); + TVM_FFI_ICHECK_GE(out_ndim, 0); } DataType out_dtype = DataType::Int(64); @@ -243,7 +243,7 @@ StructInfo InferStructInfoArgmaxArgmin(const Call& call, const BlockBuilder& ctx out_shape.push_back(IntImm(out_dtype, /*value=*/1)); } } - ICHECK_EQ(static_cast(out_shape.size()), out_ndim); + TVM_FFI_ICHECK_EQ(static_cast(out_shape.size()), out_ndim); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice); } diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index c3ee496794da..e13234054e33 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -67,16 +67,16 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { NormalizeAxis(call, ctx, data_sinfo->ndim, axis_int->value); } } - ICHECK(call->args[2]->IsInstance()); - ICHECK(call->args[3]->IsInstance()); - ICHECK(call->args[4]->IsInstance()); + TVM_FFI_ICHECK(call->args[2]->IsInstance()); + TVM_FFI_ICHECK(call->args[3]->IsInstance()); + TVM_FFI_ICHECK(call->args[4]->IsInstance()); return_index = Downcast(call->args[2]); return_inverse = Downcast(call->args[3]); return_counts = Downcast(call->args[4]); auto f_convert_to_int64 = [](const PrimExpr& value) { - CHECK(value->IsInstance()) + TVM_FFI_ICHECK(value->IsInstance()) << value << " expects to be IntImm, but gets " << value->GetTypeKey(); const auto* val_node = value.as(); auto val_imm = ffi::GetRef(val_node); diff --git a/src/relax/op/tensor/sorting.cc b/src/relax/op/tensor/sorting.cc index db0bd8a8c700..01834c9266a6 100644 --- a/src/relax/op/tensor/sorting.cc +++ b/src/relax/op/tensor/sorting.cc @@ -155,7 +155,7 @@ StructInfo InferStructInfoTopK(const Call& call, const BlockBuilder& ctx) { } else if (ret_type == "indices") { return output_sinfos[1]; } - LOG(FATAL) << "Unsupported ret type: " << ret_type; + TVM_FFI_THROW(InternalError) << "Unsupported ret type: " << ret_type; TVM_FFI_UNREACHABLE(); } diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc index 771f6ffb133f..e75aaed46d6a 100644 --- a/src/relax/op/tensor/statistical.cc +++ b/src/relax/op/tensor/statistical.cc @@ -55,7 +55,7 @@ StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx) out_ndim = kUnknownNDim; } else { out_ndim = data_sinfo->ndim - axes.size(); - ICHECK_GE(out_ndim, 0); + TVM_FFI_ICHECK_GE(out_ndim, 0); } // The inference rule for reduction operator output shapes: @@ -87,20 +87,20 @@ StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx) out_shape.push_back(IntImm(DataType::Int(64), /*value=*/1)); } } - ICHECK_EQ(static_cast(out_shape.size()), out_ndim); + TVM_FFI_ICHECK_EQ(static_cast(out_shape.size()), out_ndim); return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } InferLayoutOutput InferLayoutStatistical( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); const auto* attrs = call->attrs.as(); - ICHECK(attrs != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call"; const auto* tensor_sinfo = GetStructInfoAs(call->args[0]); - ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; - ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; + TVM_FFI_ICHECK(tensor_sinfo != nullptr) << "Invalid Call"; + TVM_FFI_ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim"; int ndim = tensor_sinfo->ndim; ffi::Array axis; @@ -198,7 +198,7 @@ StructInfo InferStructInfoStatisticalExtension(const Call& call, const BlockBuil out_ndim = kUnknownNDim; } else { out_ndim = data_sinfo->ndim - axes.size(); - ICHECK_GE(out_ndim, 0); + TVM_FFI_ICHECK_GE(out_ndim, 0); } // The inference rule for median operator output shapes: @@ -232,7 +232,7 @@ StructInfo InferStructInfoStatisticalExtension(const Call& call, const BlockBuil out_shape.push_back(IntImm(DataType::Int(64), /*value=*/1)); } } - ICHECK_EQ(static_cast(out_shape.size()), out_ndim); + TVM_FFI_ICHECK_EQ(static_cast(out_shape.size()), out_ndim); if (!attrs->axis.defined() || axes.size() > 1) return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc index a38585cb507a..6b885e420802 100644 --- a/src/relax/op/tensor/ternary.cc +++ b/src/relax/op/tensor/ternary.cc @@ -118,7 +118,7 @@ StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) { InferLayoutOutput InferLayoutEwiseFMA( const Call& call, const ffi::Map>& desired_layouts, const VarLayoutMap& var_layout_map) { - ICHECK(NoDesiredLayout(call, desired_layouts)); + TVM_FFI_ICHECK(NoDesiredLayout(call, desired_layouts)); LayoutDecision layout0 = GetLayoutDecision(var_layout_map, call->args[0]); LayoutDecision layout1 = GetLayoutDecision(var_layout_map, call->args[1]); diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc index 50f5ce2bf35f..7c88633dc280 100644 --- a/src/relax/op/tensor/unary.cc +++ b/src/relax/op/tensor/unary.cc @@ -77,10 +77,10 @@ TVM_REGISTER_OP("relax.clip") .set_attr("FPurity", Bool(true)); Expr clip(Expr x, Expr min, Expr max) { - CHECK(min->IsInstance()) + TVM_FFI_ICHECK(min->IsInstance()) << "The argument `min` of relax.clip is expected to be a PrimValue, but got " << min->GetTypeKey(); - CHECK(max->IsInstance()) + TVM_FFI_ICHECK(max->IsInstance()) << "The argument `max` of relax.clip is expected to be a PrimValue, but got " << max->GetTypeKey(); static const Op& op = Op::Get("relax.clip"); diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc index 2a1ad8f40aa4..294cd40c4515 100644 --- a/src/relax/op/vision/nms.cc +++ b/src/relax/op/vision/nms.cc @@ -60,10 +60,10 @@ StructInfo InferStructInfoAllClassNMS(const Call& call, const BlockBuilder& ctx) tvm::ffi::Array input_sinfo = GetInputTensorStructInfo(call, ctx); const auto boxes_sinfo = input_sinfo[0]; const auto scores_sinfo = input_sinfo[1]; - ICHECK(!boxes_sinfo->IsUnknownNdim()) << "Only support known ndim"; - ICHECK(!scores_sinfo->IsUnknownNdim()) << "Only support known ndim"; - ICHECK_EQ(boxes_sinfo->ndim, 3) << "AllClassNMS input boxes should be 3-D."; - ICHECK_EQ(scores_sinfo->ndim, 3) << "AllClassNMS input scores count should be 3-D."; + TVM_FFI_ICHECK(!boxes_sinfo->IsUnknownNdim()) << "Only support known ndim"; + TVM_FFI_ICHECK(!scores_sinfo->IsUnknownNdim()) << "Only support known ndim"; + TVM_FFI_ICHECK_EQ(boxes_sinfo->ndim, 3) << "AllClassNMS input boxes should be 3-D."; + TVM_FFI_ICHECK_EQ(scores_sinfo->ndim, 3) << "AllClassNMS input scores count should be 3-D."; const auto batch = boxes_sinfo->shape.as()->values[0]; const auto num_classes = scores_sinfo->shape.as()->values[1]; diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc index 26290775fe64..614f20dba7f8 100644 --- a/src/relax/training/utils.cc +++ b/src/relax/training/utils.cc @@ -42,7 +42,7 @@ class AppendLossMutator : private ExprMutator { static IRModule Transform(IRModule mod, ffi::String func_name, Function loss_function, int num_backbone_outputs, ffi::Optional new_func_name) { auto* old_func = mod->Lookup(func_name).as(); - CHECK(old_func) << func_name << "is not a Relax Function"; + TVM_FFI_ICHECK(old_func) << func_name << "is not a Relax Function"; // functions should be copied to satisfy the well-formed check Function new_func = CopyWithNewVars(ffi::GetRef(old_func)); @@ -82,7 +82,8 @@ class AppendLossMutator : private ExprMutator { } Expr VisitExpr_(const SeqExprNode* seq_expr) final { - CHECK(seq_expr->blocks.size() == 1 && seq_expr->blocks[0]->IsInstance()) + TVM_FFI_ICHECK(seq_expr->blocks.size() == 1 && + seq_expr->blocks[0]->IsInstance()) << "Backbone should have only one DataflowBlock"; auto new_blocks = ffi::Array({this->VisitBindingBlock(seq_expr->blocks[0])}); @@ -115,10 +116,11 @@ class AppendLossMutator : private ExprMutator { /*! \brief Checks the loss function have only one DataflowBlock, and returns a scalar Var. */ void CheckLossBody() { - CHECK(loss_body_->blocks.size() == 1 && loss_body_->blocks[0]->IsInstance()) + TVM_FFI_ICHECK(loss_body_->blocks.size() == 1 && + loss_body_->blocks[0]->IsInstance()) << "The loss function should have only one DataflowBlock"; auto var_node = loss_body_->body.as(); - CHECK(var_node && IsScalarTensor(ffi::GetRef(var_node))) + TVM_FFI_ICHECK(var_node && IsScalarTensor(ffi::GetRef(var_node))) << "The loss function must return a scalar(0-dim Tensor) Var"; } @@ -132,11 +134,13 @@ class AppendLossMutator : private ExprMutator { } else if (auto* tuple = backbone_return.as()) { for (auto i : tuple->fields) { auto var = i.as(); - CHECK(var) << "The return value of the backbone should be either a Var or a Tuple of Vars"; + TVM_FFI_ICHECK(var) + << "The return value of the backbone should be either a Var or a Tuple of Vars"; backbone_return_arr_.push_back(ffi::GetRef(var)); } } else { - LOG(FATAL) << "The return value of the backbone should be either a Var or a Tuple of Vars"; + TVM_FFI_THROW(InternalError) + << "The return value of the backbone should be either a Var or a Tuple of Vars"; } } @@ -147,7 +151,7 @@ class AppendLossMutator : private ExprMutator { */ void CheckAndRemapLossParams(const ffi::Array& loss_func_params) { static StructuralEqual checker; - CHECK(static_cast(loss_func_params.size()) >= num_backbone_outputs_) + TVM_FFI_ICHECK(static_cast(loss_func_params.size()) >= num_backbone_outputs_) << "The number of parameters of the loss function is " << loss_func_params.size() << ", which is less than the given num_backbone_outputs " << num_backbone_outputs_; for (int i = 0; i < num_backbone_outputs_; ++i) { @@ -156,7 +160,7 @@ class AppendLossMutator : private ExprMutator { auto loss_param_sinfo = GetStructInfo(loss_param); auto backbone_ret_sinfo = GetStructInfo(backbone_ret); - CHECK(checker(backbone_ret_sinfo, loss_param_sinfo)) + TVM_FFI_ICHECK(checker(backbone_ret_sinfo, loss_param_sinfo)) << "The struct info of the " << i << "-th return value of backbone function is: " << backbone_ret_sinfo << " while the corresponding struct info of parameter of loss function is " @@ -177,7 +181,7 @@ class AppendLossMutator : private ExprMutator { * Because such Vars are no longer the outputs of the new function. */ void CheckAndRemapBackboneReturn() { - CHECK(static_cast(backbone_return_arr_.size()) >= num_backbone_outputs_) + TVM_FFI_ICHECK(static_cast(backbone_return_arr_.size()) >= num_backbone_outputs_) << "The number of return values of the backbone function is " << backbone_return_arr_.size() << ", which is less than the given num_backbone_outputs " << num_backbone_outputs_; std::unordered_set other_outputs_var( diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 889272019174..fcbd30cd3c4a 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -143,15 +143,15 @@ std::tuple)>> if (matches.count(pat_permuted_matmul_on_lhs)) { expr_a = permute_dims(expr_a, std::nullopt); expr_b = permute_dims(expr_b, std::nullopt); - CHECK_EQ(shape_a.size(), 2); - CHECK_EQ(shape_b.size(), 2); + TVM_FFI_ICHECK_EQ(shape_a.size(), 2); + TVM_FFI_ICHECK_EQ(shape_b.size(), 2); shape_a = {shape_a[1], shape_a[0]}; shape_b = {shape_b[1], shape_b[0]}; } else if (matches.count(pat_permuted_matmul_on_rhs)) { expr_b = permute_dims(expr_b, std::nullopt); expr_c = permute_dims(expr_c, std::nullopt); - CHECK_EQ(shape_b.size(), 2); - CHECK_EQ(shape_c.size(), 2); + TVM_FFI_ICHECK_EQ(shape_b.size(), 2); + TVM_FFI_ICHECK_EQ(shape_c.size(), 2); shape_b = {shape_b[1], shape_b[0]}; shape_c = {shape_c[1], shape_c[0]}; } @@ -182,9 +182,9 @@ std::tuple)>> } else if (matches.count(pat_matmul_on_rhs)) { shape_b = {IntImm(shape_b[0].dtype(), 1), shape_b[0]}; } else { - LOG(FATAL) << "InternalError: " - << "OrPattern " << pat << " matched, but neither " << pat_matmul_on_lhs - << " nor " << pat_matmul_on_rhs << " matched"; + TVM_FFI_THROW(InternalError) + << "OrPattern " << pat << " matched, but neither " << pat_matmul_on_lhs << " nor " + << pat_matmul_on_rhs << " matched"; } } if (shape_c.size() == 1) { diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index 4e71e0c3eb43..64e52f30e5a7 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -85,7 +85,7 @@ class ExternFunctionRewriter : ExprMutator { // Append the workspace argument to this call. The callee should have been updated to accept // a workspace as the last parameter. auto new_args = call_node->args; - ICHECK(workspace_var_param_.defined()); + TVM_FFI_ICHECK(workspace_var_param_.defined()); new_args.push_back(workspace_var_param_); return Call(new_op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); } @@ -171,7 +171,7 @@ class WorkspaceProvider : ExprMutator { if (auto gv = new_op.as()) { if (new_gvars_.count(gv.value())) { auto new_args = call_node->args; - ICHECK(workspace_var_main_.defined()); + TVM_FFI_ICHECK(workspace_var_main_.defined()); new_args.push_back(workspace_var_main_); return Call(new_op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); } diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index a612ef83bde0..f066bd02daec 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -49,7 +49,7 @@ static ffi::Array ConstructRangeFromShape(const ffi::Array& sha static ffi::Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { auto shape = tensor_sinfo->GetShape(); - ICHECK(shape.defined()); + TVM_FFI_ICHECK(shape.defined()); return shape.value(); } @@ -118,7 +118,7 @@ class AlterOpImplMutator : public ExprMutator { if (call->args[0].as()) return call; // Get operator name from callee - ICHECK(call->args[0]->IsInstance()); + TVM_FFI_ICHECK(call->args[0]->IsInstance()); const tir::PrimFunc& old_func = Downcast(mod_->Lookup(Downcast(call->args[0]))); ffi::Optional maybe_op_kind = old_func->attrs.GetAttr(kOperatorName); @@ -139,10 +139,11 @@ class AlterOpImplMutator : public ExprMutator { if (op_buffer_input_axis_separators__.count(op_kind)) input_axis_separators = op_buffer_input_axis_separators__[op_kind]; - ICHECK(buffer_transforms.empty() || buffer_transforms.size() == replacement_func->params.size()) + TVM_FFI_ICHECK(buffer_transforms.empty() || + buffer_transforms.size() == replacement_func->params.size()) << "Either the i/o buffers do not require any transformations or transformations for each " "buffer is provided."; - ICHECK_EQ(old_func->params.size(), replacement_func->params.size()) + TVM_FFI_ICHECK_EQ(old_func->params.size(), replacement_func->params.size()) << "Number of parameters of old and replacement PrimFunc must match"; GlobalVar replacement_gv = GetOrCreateGlobalVarForFunc(replacement_func, op_kind); @@ -151,7 +152,8 @@ class AlterOpImplMutator : public ExprMutator { Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators, input_axis_separators); - ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir sinfo_args.size() is expected to be 1"; + TVM_FFI_ICHECK_EQ(call->sinfo_args.size(), 1) + << "call_tir sinfo_args.size() is expected to be 1"; StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0], buffer_transforms); auto updated_call = builder_->Normalize( Call(call_tir_op_, {replacement_gv, updated_inputs}, call->attrs, {updated_ret_sinfo})); @@ -165,13 +167,13 @@ class AlterOpImplMutator : public ExprMutator { if (const auto* tensor_sinfo = output_sinfo.as()) return {ffi::GetRef(tensor_sinfo)}; const auto* tuple_sinfo = output_sinfo.as(); - ICHECK(tuple_sinfo); + TVM_FFI_ICHECK(tuple_sinfo); ffi::Array arr_tensor_sinfo; arr_tensor_sinfo.reserve(tuple_sinfo->fields.size()); for (const auto& sinfo : tuple_sinfo->fields) { const auto* tensor_sinfo = sinfo.as(); - ICHECK(tensor_sinfo) << "Nested tuples in output of call_tir is not supported yet"; + TVM_FFI_ICHECK(tensor_sinfo) << "Nested tuples in output of call_tir is not supported yet"; arr_tensor_sinfo.push_back(ffi::GetRef(tensor_sinfo)); } return arr_tensor_sinfo; @@ -324,7 +326,7 @@ class AlterOpImplMutator : public ExprMutator { return UpdateStructInfo(Downcast(out_sinfo), buffer_transforms[buffer_transforms.size() - 1]); - ICHECK(out_sinfo->IsInstance()) + TVM_FFI_ICHECK(out_sinfo->IsInstance()) << "Expect output struct info of call_tir to be either TupleStructInfo or " "TensorStructInfo, but got " << out_sinfo; @@ -334,7 +336,7 @@ class AlterOpImplMutator : public ExprMutator { size_t first_output_index = buffer_transforms.size() - tuple_sinfo->fields.size(); size_t i = 0; for (const auto& si : tuple_sinfo->fields) { - ICHECK(si->IsInstance()) + TVM_FFI_ICHECK(si->IsInstance()) << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " "output structinfo, but got " << si; diff --git a/src/relax/transform/attach_attr_layout_free_buffers.cc b/src/relax/transform/attach_attr_layout_free_buffers.cc index 129294bf6613..e1e2e0ca26b8 100644 --- a/src/relax/transform/attach_attr_layout_free_buffers.cc +++ b/src/relax/transform/attach_attr_layout_free_buffers.cc @@ -50,7 +50,8 @@ class AttrAttacher : public ExprMutator { using ExprMutator::VisitExpr_; Expr VisitExpr_(const FunctionNode* op) final { if (auto opt_num_input = op->attrs.GetAttr(attr::kNumInput)) { - ICHECK(layout_free_exprs_.empty()) << "meet a non-global function with num_input attr"; + TVM_FFI_ICHECK(layout_free_exprs_.empty()) + << "meet a non-global function with num_input attr"; size_t num_input = opt_num_input.value()->value; for (size_t i = num_input; i < op->params.size(); i++) { layout_free_exprs_.insert(op->params[i].get()); diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index 4ad9b3ab5051..ea93a77ee7e1 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -34,47 +34,47 @@ namespace relax { void MatchSymbolicVar(const Expr& arg, const Expr& constant, ffi::Map* symbolic_var_map, arith::Analyzer* analyzer_) { auto opt_arg_sinfo = MatchStructInfo(arg); - CHECK(opt_arg_sinfo) + TVM_FFI_ICHECK(opt_arg_sinfo) << "The struct info of the bound parameter is expected to be TensorStructInfo, but got: " << GetStructInfo(arg); auto opt_const_sinfo = MatchStructInfo(constant); // As the constant is generated by internal codes, we use ICHECK here. - ICHECK(opt_const_sinfo) + TVM_FFI_ICHECK(opt_const_sinfo) << "The struct info of the bound weight is expected to be TensorStructInfo, but got: " << GetStructInfo(constant); TensorStructInfo arg_sinfo = opt_arg_sinfo.value(); TensorStructInfo const_sinfo = opt_const_sinfo.value(); - ICHECK(!const_sinfo->IsUnknownDtype()); - ICHECK(!const_sinfo->IsUnknownNdim()); - ICHECK(const_sinfo->shape.defined()); + TVM_FFI_ICHECK(!const_sinfo->IsUnknownDtype()); + TVM_FFI_ICHECK(!const_sinfo->IsUnknownNdim()); + TVM_FFI_ICHECK(const_sinfo->shape.defined()); // dtype mismatch if (!arg_sinfo->IsUnknownDtype() && arg_sinfo->dtype != const_sinfo->dtype) { - LOG(FATAL) << "The dtype of the bound parameter is expected to be " << arg_sinfo->dtype - << ", but got: " << const_sinfo->dtype; + TVM_FFI_THROW(InternalError) << "The dtype of the bound parameter is expected to be " + << arg_sinfo->dtype << ", but got: " << const_sinfo->dtype; } // ndim mismatch if (!arg_sinfo->IsUnknownNdim() && arg_sinfo->ndim != const_sinfo->ndim) { - LOG(FATAL) << "The ndim of the bound parameter is expected to be " << arg_sinfo->ndim - << ", but got: " << const_sinfo->ndim; + TVM_FFI_THROW(InternalError) << "The ndim of the bound parameter is expected to be " + << arg_sinfo->ndim << ", but got: " << const_sinfo->ndim; } if (!arg_sinfo->shape.defined()) return; const auto* arg_shape = arg_sinfo->shape.value().as(); const auto* const_shape = const_sinfo->shape.value().as(); - CHECK(arg_shape && const_shape) + TVM_FFI_ICHECK(arg_shape && const_shape) << "The shape of the bound parameter and weight is expected to be ShapeExprNode for now"; for (int i = 0; i < arg_sinfo->ndim; ++i) { const PrimExpr& const_dim = const_shape->values[i]; - ICHECK(tir::is_const_int(const_dim)); + TVM_FFI_ICHECK(tir::is_const_int(const_dim)); if (const auto* shape_var = arg_shape->values[i].as()) { auto it = symbolic_var_map->find(ffi::GetRef(shape_var)); if (it == symbolic_var_map->end()) { symbolic_var_map->Set(ffi::GetRef(shape_var), const_dim); } else { - CHECK(analyzer_->CanProveEqual((*it).second, const_dim)) + TVM_FFI_ICHECK(analyzer_->CanProveEqual((*it).second, const_dim)) << "The shape of the bound parameter is expected to be " << (*it).second << ", but got: " << const_dim; } @@ -84,8 +84,8 @@ void MatchSymbolicVar(const Expr& arg, const Expr& constant, std::tuple, ffi::Map> NormalizeBindings( const Function& func, const ffi::Map& untyped_params) { - ICHECK(func.defined()); - ICHECK(untyped_params.defined()); + TVM_FFI_ICHECK(func.defined()); + TVM_FFI_ICHECK(untyped_params.defined()); // Map from string to the variable(s) with that name. std::unordered_map> string_lookup; @@ -101,28 +101,28 @@ std::tuple, ffi::Map> NormalizeBindings( if (auto opt_str = obj.as()) { std::string str = opt_str.value(); auto it = string_lookup.find(str); - CHECK(it != string_lookup.end()) + TVM_FFI_ICHECK(it != string_lookup.end()) << "Function does not have parameter with name \"" << str << "\". " << "Function parameters are named " << func->params.Map([](const auto& param) { return param->name_hint(); }); - CHECK_EQ(it->second.size(), 1) + TVM_FFI_ICHECK_EQ(it->second.size(), 1) << "Function contains multiple parameters with name \"" << str << "\". " << "The Relax variables " << it->second << " are all named \"" << str << "\""; auto var = it->second[0]; - CHECK(!relax_var_remap.count(var)) + TVM_FFI_ICHECK(!relax_var_remap.count(var)) << "Remap of variable " << var << " was defined multiple times"; return var; } else if (auto opt_var = obj.as()) { auto var = opt_var.value(); - CHECK(!relax_var_remap.count(var)) + TVM_FFI_ICHECK(!relax_var_remap.count(var)) << "Remap of variable " << var << " was defined multiple times"; - CHECK(var_set.count(var.get())) + TVM_FFI_ICHECK(var_set.count(var.get())) << "Function does not use Relax variable " << var << " as a parameter. " << "Function parameters are " << func->params; return var; } else { - LOG(FATAL) + TVM_FFI_THROW(InternalError) << "Expected bound parameter to be a relax::Var, " << " or a string that uniquely identifies a relax::Var param within the function. " << "However, received object " << obj << " of type " << obj.GetTypeKey(); @@ -134,7 +134,8 @@ std::tuple, ffi::Map> NormalizeBindings( } else if (auto opt = obj.as()) { return Constant(opt.value()); } else { - LOG(FATAL) << "Cannot coerce object of type " << obj.GetTypeKey() << " into relax expression"; + TVM_FFI_THROW(InternalError) + << "Cannot coerce object of type " << obj.GetTypeKey() << " into relax expression"; } }; diff --git a/src/relax/transform/bind_symbolic_vars.cc b/src/relax/transform/bind_symbolic_vars.cc index 04a4b0819cda..b7d69186f4b5 100644 --- a/src/relax/transform/bind_symbolic_vars.cc +++ b/src/relax/transform/bind_symbolic_vars.cc @@ -54,28 +54,31 @@ Function FunctionBindSymbolicVars( if (auto opt = key.as()) { ffi::String string_key = opt.value(); auto it = string_lookup.find(string_key); - CHECK(it != string_lookup.end()) + TVM_FFI_ICHECK(it != string_lookup.end()) << "Function does not use symbolic var with name \"" << string_key << "\". " << "Function has symbolic variables " << old_symbolic_vars; - CHECK_EQ(it->second.size(), 1) + TVM_FFI_ICHECK_EQ(it->second.size(), 1) << "Function contains multiple symbolic variables with name \"" << string_key << "\". " << "The TIR variables " << it->second << " are all named \"" << string_key << "\""; auto var = it->second[0]; - CHECK(!var_remap.count(var)) << "Remap of variable " << var << " was defined multiple times"; + TVM_FFI_ICHECK(!var_remap.count(var)) + << "Remap of variable " << var << " was defined multiple times"; var_remap.Set(var, replacement); } else if (auto opt = key.as()) { auto var = opt.value(); - CHECK(!var_remap.count(var)) << "Remap of variable " << var << " was defined multiple times"; - CHECK(symbolic_var_set.count(var.get())) + TVM_FFI_ICHECK(!var_remap.count(var)) + << "Remap of variable " << var << " was defined multiple times"; + TVM_FFI_ICHECK(symbolic_var_set.count(var.get())) << "Function does not use variable " << var << " as a symbolic variable. " << "Function has symbolic variables " << old_symbolic_vars; var_remap.Set(var, replacement); } else { - LOG(FATAL) << "Expected symbolic variable to be a tir::Var or a string name, " - << "but " << key << " was of type " << key.GetTypeKey(); + TVM_FFI_THROW(InternalError) + << "Expected symbolic variable to be a tir::Var or a string name, " + << "but " << key << " was of type " << key.GetTypeKey(); } } @@ -83,7 +86,7 @@ Function FunctionBindSymbolicVars( auto free_symbolic_vars = FreeSymbolicVars(new_func); - CHECK(free_symbolic_vars.empty()) + TVM_FFI_ICHECK(free_symbolic_vars.empty()) << "Resulting function should not have any undefined symbolic variables, " << "but TIR variables " << free_symbolic_vars << " were undefined."; @@ -116,8 +119,9 @@ IRModule ModuleBindSymbolicVars( } else if (auto ptr = key.as()) { used_by_function = vars.count(ptr); } else { - LOG(FATAL) << "Expected symbolic variable to be a tir::Var " - << "or a string name, but " << key << " was of type " << key.GetTypeKey(); + TVM_FFI_THROW(InternalError) + << "Expected symbolic variable to be a tir::Var " + << "or a string name, but " << key << " was of type " << key.GetTypeKey(); } if (used_by_function) { used.insert(key); @@ -140,9 +144,9 @@ IRModule ModuleBindSymbolicVars( unused.push_back(key); } } - CHECK_EQ(unused.size(), 0) << "Binding map contains keys " << unused - << ", which did not correspond to any symbolic variables " - << "in the module."; + TVM_FFI_ICHECK_EQ(unused.size(), 0) << "Binding map contains keys " << unused + << ", which did not correspond to any symbolic variables " + << "in the module."; if (updates->functions.size()) { mod.CopyOnWrite()->Update(updates); diff --git a/src/relax/transform/bundle_model_params.cc b/src/relax/transform/bundle_model_params.cc index 877f3d7dea35..fee595c9f364 100644 --- a/src/relax/transform/bundle_model_params.cc +++ b/src/relax/transform/bundle_model_params.cc @@ -45,8 +45,8 @@ class ModelParamBundler : public ExprMutator { if (!opt_num_input) return func; auto signed_num_input = opt_num_input.value()->value; - ICHECK_GE(signed_num_input, 0); - ICHECK_LE(signed_num_input, func->params.size()) + TVM_FFI_ICHECK_GE(signed_num_input, 0); + TVM_FFI_ICHECK_LE(signed_num_input, func->params.size()) << "Function was declared to have " << signed_num_input << " runtime inputs, " << "but only has " << func->params.size() << " parameters total."; size_t num_input = signed_num_input; diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index 2c0e515be7ec..48efd50d482c 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -78,7 +78,7 @@ class CallTIRMutator : public ExprMutator { if (const auto& _tensor_sinfo = MatchStructInfo(expr)) { // single output case const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value(); - ICHECK(tensor_sinfo->shape.defined()) + TVM_FFI_ICHECK(tensor_sinfo->shape.defined()) << "the TensorStructInfo shape of call_tir has not populated"; int dev_index = 0; ffi::String scope = "global"; @@ -98,7 +98,7 @@ class CallTIRMutator : public ExprMutator { "alloc")); } else { // if there is only one output, it must be an in-place argument, but check anyway - ICHECK(inplace_attrs->inplace_indices[0].IntValue() != -1) + TVM_FFI_ICHECK(inplace_attrs->inplace_indices[0].IntValue() != -1) << "If calling call_tir_inplace and there is one output, its in-place index must not" " be -1."; outs.push_back( @@ -110,11 +110,11 @@ class CallTIRMutator : public ExprMutator { for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { const auto& field = tuple_sinfo->fields[i]; - ICHECK(field->IsInstance()) + TVM_FFI_ICHECK(field->IsInstance()) << "call_tir expects Tuple of TensorStructInfo, but got " << field << " as an element of TupleStructInfo"; const auto& field_tensor = Downcast(field); - ICHECK(field_tensor->shape.defined()) + TVM_FFI_ICHECK(field_tensor->shape.defined()) << "call_tir expects all TensorStructInfo has shape, but got " << field_tensor << " as an element of TupleStructInfo"; @@ -138,9 +138,9 @@ class CallTIRMutator : public ExprMutator { } } } else { - LOG(FATAL) << "TypeError: The struct info of call_tir expects to be TensorStructInfo or " - "TupleStructInfo, but got" - << expr->struct_info_; + TVM_FFI_THROW(TypeError) << "The struct info of call_tir expects to be TensorStructInfo or " + "TupleStructInfo, but got" + << expr->struct_info_; } ffi::Array args; diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index decbecd3098b..05c86d92630d 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -51,8 +51,7 @@ class SymbolicVarCanonicalizer : public ExprMutator { InferSymbolicVarMap({{binding->var, binding->value}}, builder_->GetAnalyzer()); for (const auto& [tir_var, prim_expr] : tir_var_map) { if (auto it = known_values_.find(tir_var); it != known_values_.end()) { - CHECK(!builder_->GetAnalyzer()->CanProve(it->second.expr != prim_expr)) - << "ValueError: " + TVM_FFI_CHECK(!builder_->GetAnalyzer()->CanProve(it->second.expr != prim_expr), ValueError) << "MatchCast statements must be consistent. " << "However, the definition of Relax variable " << it->second.source->var << " implies that TIR variable " << tir_var << " is " << it->second.expr @@ -250,14 +249,14 @@ class CanonicalizePlanner : public ExprVisitor { } void VisitBindingBlock_(const BindingBlockNode* block) override { - CHECK(!current_block_.defined()) << "Forgetting to unset current block"; + TVM_FFI_ICHECK(!current_block_.defined()) << "Forgetting to unset current block"; current_block_ = ffi::GetRef(block); ExprVisitor::VisitBindingBlock_(block); current_block_ = ffi::Optional(); } void VisitBindingBlock_(const DataflowBlockNode* block) override { - CHECK(!current_block_.defined()) << "Forgetting to unset current block"; + TVM_FFI_ICHECK(!current_block_.defined()) << "Forgetting to unset current block"; current_block_ = ffi::GetRef(block); ExprVisitor::VisitBindingBlock_(block); current_block_ = ffi::Optional(); @@ -354,7 +353,7 @@ class CanonicalizePlanner : public ExprVisitor { } else if (auto match_cast = binding.as()) { return StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(match_cast->value)); } else { - LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Invalid binding type: " << binding->GetTypeKey(); } }(); @@ -539,7 +538,7 @@ class BindingCanonicalizer : public ExprMutator { VarBinding(binding->var, candidates.at(Downcast(var_binding->value))); new_bindings.push_back(new_binding); } else { - CHECK(false) << "Invalid binding"; // never happens + TVM_FFI_ICHECK(false) << "Invalid binding"; // never happens } } else { new_bindings.push_back(binding); diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index c60864d671c5..a46b5c5b5546 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -122,7 +122,7 @@ ffi::TypedFunction(ffi::Map, ffi::Map>& rhs_shapes) { arith::Analyzer ana; for (auto ind : indices) { - ICHECK_EQ(static_cast(rhs_shapes[ind].size()), rhs_dim); + TVM_FFI_ICHECK_EQ(static_cast(rhs_shapes[ind].size()), rhs_dim); // -2 for reduction and concat axes for (size_t i = 0; i < rhs_dim - 2; ++i) { if (!ana.CanProve(rhs_shapes[indices[0]][i] == rhs_shapes[ind][i])) { @@ -223,7 +223,7 @@ ffi::TypedFunction(ffi::Map, ffi::Map(ffi::Map, ffi::Map sections; for (size_t i = 0; i + 1 < splits.size(); i++) { auto width = splits[i].split_size.as(); - ICHECK(width) << "InternalError: " - << "All splits except the last one must have a static shape"; + TVM_FFI_CHECK(width, InternalError) + << "All splits except the last one must have a static shape"; split_index += width->value; sections.push_back(IntImm(DataType::Int(64), split_index)); } diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index 27684313de02..c71675bb26dd 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -114,11 +114,11 @@ class LayoutConvertMutator : public ExprMutator { NLayout from = layouts[0], to = layouts[1]; if (NLayoutEqual()(from, to) || layouts[0].LeafValue()->layout.name() == "") return expr; // If not both from and to are unknown, then none of them can be unknown. - ICHECK(!NLayoutEqual()(from, LayoutDecision::InitUnknownDim()) && - !NLayoutEqual()(to, LayoutDecision::InitUnknownDim())) + TVM_FFI_ICHECK(!NLayoutEqual()(from, LayoutDecision::InitUnknownDim()) && + !NLayoutEqual()(to, LayoutDecision::InitUnknownDim())) << "Cannot convert when exactly one of the layouts is unknown"; const auto* tensor = GetStructInfoAs(expr); - ICHECK(tensor != nullptr) << "Expect a tensor, but got: " << expr; + TVM_FFI_ICHECK(tensor != nullptr) << "Expect a tensor, but got: " << expr; if (from.LeafValue()->layout.ndim() == to.LeafValue()->layout.ndim()) { Layout axes = TransposeLike(InitialLayoutDecision(tensor->ndim)->layout, @@ -149,7 +149,7 @@ class LayoutConvertMutator : public ExprMutator { // contains tensor arguments. The number of tensor arguments in // `args` should match the full extent of `to`. - ICHECK_LE(to.size(), args.size()); + TVM_FFI_ICHECK_LE(to.size(), args.size()); std::vector new_args; for (size_t i = 0; i < args.size(); ++i) { @@ -258,7 +258,7 @@ class LayoutConvertMutator : public ExprMutator { var_layout_map_[binding->var] = res.value()->output_layouts[0]; } else { // Global var (tensor), we rewrite it to initial layout - ICHECK(IsNestedTensor(binding->var)); + TVM_FFI_ICHECK(IsNestedTensor(binding->var)); if (!NLayoutEqual()(res.value()->output_layouts[0], InitialNLayout(binding->var))) { Var new_var = builder_->Emit(cur_call); var_layout_map_[new_var] = res.value()->output_layouts[0]; @@ -310,15 +310,15 @@ class LayoutConvertMutator : public ExprMutator { NLayout from = layouts[0], to = layouts[1]; if (NLayoutEqual()(from, to)) return sinfo; // If not both from and to are unknown, then none of them can be unknown. - ICHECK(!NLayoutEqual()(from, LayoutDecision::InitUnknownDim()) && - !NLayoutEqual()(to, LayoutDecision::InitUnknownDim())) + TVM_FFI_ICHECK(!NLayoutEqual()(from, LayoutDecision::InitUnknownDim()) && + !NLayoutEqual()(to, LayoutDecision::InitUnknownDim())) << "Cannot convert when exactly one of the layouts is unknown"; const TensorStructInfoNode* tsinfo = sinfo.as(); - ICHECK(tsinfo != nullptr) << "We can not set layout for non-tensor struct"; + TVM_FFI_ICHECK(tsinfo != nullptr) << "We can not set layout for non-tensor struct"; if (!tsinfo->shape.defined()) return sinfo; const ShapeExprNode* shape = tsinfo->shape.value().as(); if (shape == nullptr) return sinfo; - ICHECK_EQ(shape->values.size(), to.LeafValue()->layout.ndim()); + TVM_FFI_ICHECK_EQ(shape->values.size(), to.LeafValue()->layout.ndim()); std::vector new_shape; for (size_t i = 0; i < shape->values.size(); ++i) { new_shape.push_back( diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index cf6b690ae34a..28a5c59defc2 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -86,7 +86,7 @@ std::unordered_map> AnalyzeLiveness(const DataflowBlock } else { // this means the var is used later but we encountered its definition now auto last_range = ret[defined_var]; - CHECK_EQ(last_range.first, -1); + TVM_FFI_ICHECK_EQ(last_range.first, -1); std::pair new_range = {i, last_range.second}; ret[defined_var] = new_range; } diff --git a/src/relax/transform/decompose_ops.cc b/src/relax/transform/decompose_ops.cc index 81d4d3881ede..1b6fc77c48d8 100644 --- a/src/relax/transform/decompose_ops.cc +++ b/src/relax/transform/decompose_ops.cc @@ -32,7 +32,7 @@ namespace relax { TensorStructInfo MatchTensorStructInfo(Expr data) { auto _sinfo = MatchStructInfo(data); - ICHECK(_sinfo.defined()) << "Expect data to be a tensor, but get " << GetStructInfo(data); + TVM_FFI_ICHECK(_sinfo.defined()) << "Expect data to be a tensor, but get " << GetStructInfo(data); return _sinfo.value(); } @@ -51,7 +51,7 @@ Expr ExpandToMatchInput(Expr data, int ndim, ffi::Array axes) { Tuple DecomposeBatchNorm(const Call& call) { auto attrs = call->attrs.as(); - ICHECK_NOTNULL(attrs); + TVM_FFI_ICHECK_NOTNULL(attrs); Expr data = call->args[0]; TensorStructInfo sinfo = MatchTensorStructInfo(data); @@ -78,9 +78,9 @@ Tuple DecomposeBatchNorm(const Call& call) { Expr MutateBatchNormForTraining(Call call) { auto attrs = call->attrs.as(); - ICHECK_NOTNULL(attrs); + TVM_FFI_ICHECK_NOTNULL(attrs); - ICHECK_EQ(call->args.size(), 5); + TVM_FFI_ICHECK_EQ(call->args.size(), 5); Expr data = call->args[0]; Expr gamma = call->args[1]; Expr beta = call->args[2]; @@ -113,7 +113,7 @@ Expr MutateBatchNormForTraining(Call call) { Expr DecomposeLayerNorm(const Call& call) { auto attrs = call->attrs.as(); - ICHECK_NOTNULL(attrs); + TVM_FFI_ICHECK_NOTNULL(attrs); Expr data = call->args[0]; TensorStructInfo sinfo = MatchTensorStructInfo(data); @@ -139,10 +139,10 @@ Expr DecomposeLayerNorm(const Call& call) { } Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) { - ICHECK(call_node->struct_info_.defined()); + TVM_FFI_ICHECK(call_node->struct_info_.defined()); Expr expr = call_node->args[0]; const ShapeStructInfoNode* sinfo = GetStructInfoAs(call_node); - ICHECK(sinfo); + TVM_FFI_ICHECK(sinfo); // call builtin function that converts tensor to shape tuple // TODO(@sunggg): Register operator for "vm.builtin.tensor_to_shape" static const Op& call_pure_packed_op = Op::Get("relax.call_pure_packed"); diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc index e893b5151b52..7e7f069cdd14 100644 --- a/src/relax/transform/eliminate_common_subexpr.cc +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -114,8 +114,8 @@ class CommonSubexprEliminator : public ExprMutator { } else if (auto match_cast = binding.as()) { return MatchCast(binding->var, bound_value, match_cast->struct_info); } else { - LOG(FATAL) << "Binding must be either VarBinding or MatchCast, " - << "but was " << binding->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Binding must be either VarBinding or MatchCast, " + << "but was " << binding->GetTypeKey(); } }(); diff --git a/src/relax/transform/expand_matmul_of_sum.cc b/src/relax/transform/expand_matmul_of_sum.cc index 5504c2a59942..201e3309a151 100644 --- a/src/relax/transform/expand_matmul_of_sum.cc +++ b/src/relax/transform/expand_matmul_of_sum.cc @@ -80,8 +80,8 @@ std::tuple)>> if (matches.count(pat_rhs_permute_dims)) { auto call_permute = Downcast(matches[pat_rhs_permute_dims]); auto attrs = call_permute->attrs.as(); - ICHECK(attrs) << "Operator permute_dims should have PermuteDimsAttrs, " - << "but " << call_permute << " has attributes " << call_permute->attrs; + TVM_FFI_ICHECK(attrs) << "Operator permute_dims should have PermuteDimsAttrs, " + << "but " << call_permute << " has attributes " << call_permute->attrs; auto axes = attrs->axes; rhs_a = permute_dims(rhs_a, axes); diff --git a/src/relax/transform/few_shot_tuning.cc b/src/relax/transform/few_shot_tuning.cc index 7f42c6d2ef03..a88b92e8e4e0 100644 --- a/src/relax/transform/few_shot_tuning.cc +++ b/src/relax/transform/few_shot_tuning.cc @@ -33,14 +33,14 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& tvm::ffi::Function::GetGlobalRequired("s_tir.meta_schedule.builder.get_local_builder"); s_tir::meta_schedule::Builder builder = f_get_local_builder().cast(); - ICHECK(builder.defined()) << "ValueError: The local builder is not defined!"; + TVM_FFI_CHECK(builder.defined(), ValueError) << "The local builder is not defined!"; // fetch a local runner s_tir::meta_schedule::Runner runner{ffi::UnsafeInit()}; if (benchmark) { static const auto f_get_local_runner = tvm::ffi::Function::GetGlobalRequired("s_tir.meta_schedule.runner.get_local_runner"); runner = f_get_local_runner().cast(); - ICHECK(runner.defined()) << "ValueError: The local runner is not defined!"; + TVM_FFI_CHECK(runner.defined(), ValueError) << "The local runner is not defined!"; } // create an IRModule IRModule mod = IRModule(ffi::Map( @@ -85,7 +85,7 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& } ffi::Array builder_results = builder->Build(builder_inputs); - ICHECK_EQ(builder_results.size(), candidates.value().size()); + TVM_FFI_ICHECK_EQ(builder_results.size(), candidates.value().size()); int idx = 0; bool no_valid = true; // whether there is no valid schedule in this iteration for (const s_tir::meta_schedule::BuilderResult& builder_result : builder_results) { @@ -122,7 +122,7 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const Target& costs.push_back(sum / runner_result->run_secs.value().size()); } } - ICHECK_EQ(costs.size(), results.size()); + TVM_FFI_ICHECK_EQ(costs.size(), results.size()); } } if (results.size() == 0) { @@ -150,11 +150,12 @@ Pass FewShotTuning(int valid_count, bool benchmark) { auto pass_func = // [=](IRModule m, PassContext pc) { // input check - CHECK(valid_count > 0) << "Valid_count must be positive."; - CHECK(valid_count > 1 || !benchmark) << "Benchmarking requires at least two valid trials."; + TVM_FFI_ICHECK(valid_count > 0) << "Valid_count must be positive."; + TVM_FFI_ICHECK(valid_count > 1 || !benchmark) + << "Benchmarking requires at least two valid trials."; // get the target from context. tvm::Target target = tvm::Target::Current(); - ICHECK(target.defined()) << "Target is not set in current context"; + TVM_FFI_ICHECK(target.defined()) << "Target is not set in current context"; // generate the few shot tuned prim funcs. ffi::Map result; for (const auto& [gv, func] : m->functions) { diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 86653767662a..3a289ebfff49 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -57,7 +57,7 @@ class ConstantFolder : public ExprMutator { } const auto* shape = tensor_sinfo->shape.as(); - ICHECK(shape != nullptr) << "struct info given by call_tir should have ShapeExpr shape"; + TVM_FFI_ICHECK(shape != nullptr) << "struct info given by call_tir should have ShapeExpr shape"; std::vector shape_values; for (const auto v : shape->values) { @@ -268,12 +268,12 @@ class ConstantFolder : public ExprMutator { // Returns the folded expr if the call is successfully folded to constant, otherwise null. ffi::Optional VisitCallTIR(Call call) { // call_tir needs to have at least two arguments - ICHECK_GE(call->args.size(), 2); + TVM_FFI_ICHECK_GE(call->args.size(), 2); ffi::Optional func = MatchPrimFunc(call->args[0]); - ICHECK(call->args[1].as()) << "call_tir.args[1] must be Tuple"; + TVM_FFI_ICHECK(call->args[1].as()) << "call_tir.args[1] must be Tuple"; ffi::Optional> arr_args = MatchConstArrayArgs(call->args[1].as()->fields); - ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one sinfo arg"; + TVM_FFI_ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one sinfo arg"; if (!func || !arr_args) return {}; @@ -361,15 +361,15 @@ class ConstantFolder : public ExprMutator { // Thus, this is a temporary solution until we have a consensus about // how to deal with composite ops. One possibility is we register the // decomposition map for each op in a similar way we do for legalization. - ICHECK_EQ(post_call->args.size(), 1); + TVM_FFI_ICHECK_EQ(post_call->args.size(), 1); Expr arg = post_call->args[0]; if (arg->IsInstance()) { Constant constant = Downcast(arg); runtime::Tensor ndarray = constant->data; - ICHECK_EQ(ndarray->device.device_type, kDLCPU); - ICHECK(ndarray.IsContiguous()); - ICHECK_EQ(ndarray->byte_offset, 0); - ICHECK_EQ(ndarray->ndim, 1); + TVM_FFI_ICHECK_EQ(ndarray->device.device_type, kDLCPU); + TVM_FFI_ICHECK(ndarray.IsContiguous()); + TVM_FFI_ICHECK_EQ(ndarray->byte_offset, 0); + TVM_FFI_ICHECK_EQ(ndarray->ndim, 1); const int64_t* data = static_cast(ndarray->data); int64_t num_elems = ndarray->shape[0]; ffi::Array shape_values; diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 561695787de8..45888419b4be 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -130,8 +130,8 @@ class GraphCreator : public ExprVisitor { // post-dfs order and will be set its op pattern. Thus we check whether all these containers // have the same size. size_t n_nodes = creator.graph_.node_map.size(); - ICHECK_EQ(n_nodes, creator.graph_.post_dfs_order.size()); - ICHECK_EQ(n_nodes, creator.initialized_nodes_.size()); + TVM_FFI_ICHECK_EQ(n_nodes, creator.graph_.post_dfs_order.size()); + TVM_FFI_ICHECK_EQ(n_nodes, creator.initialized_nodes_.size()); return creator.graph_; } @@ -189,7 +189,7 @@ class GraphCreator : public ExprVisitor { /********** Non-Leaf Expression Nodes **********/ void VisitCall(const CallNode* call, IndexedForwardGraph::Node* binding_var_node) { - ICHECK_NOTNULL(binding_var_node); + TVM_FFI_ICHECK_NOTNULL(binding_var_node); static const Op& call_tir_op_ = Op::Get("relax.call_tir"); static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); @@ -220,7 +220,7 @@ class GraphCreator : public ExprVisitor { SetNodePattern(binding_var_node, pattern); // Visit all call args for (const Expr& arg : args) { - ICHECK(IsLeafOrTuple(arg)) + TVM_FFI_ICHECK(IsLeafOrTuple(arg)) << "FuseOps expects all relax::Call nodes to have non-nested arguments, " << "but " << ffi::GetRef(call) << " has argument " << arg << ", which is neither a leaf node nor a relax::Tuple"; @@ -230,7 +230,7 @@ class GraphCreator : public ExprVisitor { void VisitTupleGetItem(const TupleGetItemNode* tuple_item, IndexedForwardGraph::Node* binding_var_node) { - ICHECK_NOTNULL(binding_var_node); + TVM_FFI_ICHECK_NOTNULL(binding_var_node); auto pattern = OpPatternKind::kInjective; if (input_params_.count(tuple_item->tuple.as())) { @@ -244,7 +244,7 @@ class GraphCreator : public ExprVisitor { } void VisitUnsupportedNode(const Expr& expr, IndexedForwardGraph::Node* binding_var_node) { - ICHECK_NOTNULL(binding_var_node); + TVM_FFI_ICHECK_NOTNULL(binding_var_node); SetNodePattern(binding_var_node, OpPatternKind::kOpaque); auto visit_leaves = [this, &binding_var_node](const Expr& e) { @@ -259,7 +259,7 @@ class GraphCreator : public ExprVisitor { void VisitLeaf(const Expr& leaf_expr, IndexedForwardGraph::Node* binding_var_node, const OpPatternKind& pattern) { - ICHECK_NOTNULL(binding_var_node); + TVM_FFI_ICHECK_NOTNULL(binding_var_node); // Recursive visit if it's Tuple if (const auto* tuple = leaf_expr.as()) { @@ -296,7 +296,7 @@ class GraphCreator : public ExprVisitor { * \note The node corresponding to each key is supposed to be created for only once */ IndexedForwardGraph::Node* CreateNode(const Object* key) { - ICHECK(graph_.node_map.find(key) == graph_.node_map.end()) + TVM_FFI_ICHECK(graph_.node_map.find(key) == graph_.node_map.end()) << "The object " << ffi::GetRef(key) << " appears at multiple definition sites."; auto* node = arena_->make(); graph_.node_map[key] = node; @@ -311,15 +311,15 @@ class GraphCreator : public ExprVisitor { */ void AddToPostDFSOrder(IndexedForwardGraph::Node* node, const Object* key) { auto it = graph_.node_map.find(key); - ICHECK(it != graph_.node_map.end() && it->second == node) + TVM_FFI_ICHECK(it != graph_.node_map.end() && it->second == node) << "Cannot add node " << ffi::GetRef(key) << " to the post-DFS order, " << "because the node for this object has not yet been created."; // We only set the reference of the node when adding it to the post-dfs order. Thus, if the // reference of a node is already set, it must have been appended to the post-dfs order. - ICHECK(node->ref == nullptr) << "Cannot add node " << ffi::GetRef(key) - << " to the post-DFS order, " - << "because it has already been added."; + TVM_FFI_ICHECK(node->ref == nullptr) + << "Cannot add node " << ffi::GetRef(key) << " to the post-DFS order, " + << "because it has already been added."; node->ref = key; node->index = graph_.post_dfs_order.size(); @@ -353,7 +353,7 @@ class GraphCreator : public ExprVisitor { * \param pattern The pattern of the node */ void SetNodePattern(IndexedForwardGraph::Node* node, OpPatternKind pattern) { - ICHECK(initialized_nodes_.find(node) == initialized_nodes_.end()) + TVM_FFI_ICHECK(initialized_nodes_.find(node) == initialized_nodes_.end()) << "The input node " << ffi::GetRef(node->ref) << " cannot have have its OpPatternKind set more than once."; initialized_nodes_.insert(node); @@ -398,7 +398,7 @@ class FunctionCreator : public ExprMutator { * // TODO(tvm-team): handle match shape */ void AppendBinding(const Binding& binding) { - ICHECK(!function_.defined()) + TVM_FFI_ICHECK(!function_.defined()) << "The `function_` is supposed to be uncreated when adding bindings"; if (const auto* var_binding = binding.as()) { @@ -411,7 +411,7 @@ class FunctionCreator : public ExprMutator { const Tuple& args = Downcast(call->args[1]); for (const Expr& arg : args->fields) { CheckDefAndUpdateParam(arg); - ICHECK(GetStructInfoAs(arg) == nullptr); + TVM_FFI_ICHECK(GetStructInfoAs(arg) == nullptr); } // TODO(tvm-team): handle shape expr } else { @@ -430,7 +430,7 @@ class FunctionCreator : public ExprMutator { if (auto tuple = arg.as()) { for (const Expr& tup_arg : tuple->fields) { CheckDefAndUpdateParam(tup_arg); - ICHECK(GetStructInfoAs(tup_arg) == nullptr); + TVM_FFI_ICHECK(GetStructInfoAs(tup_arg) == nullptr); } } else { CheckDefAndUpdateParam(arg); @@ -466,7 +466,7 @@ class FunctionCreator : public ExprMutator { /*! \brief Set a var defined in the group as output. */ size_t AppendOutput(const Var& var) { - ICHECK(defined_vars_.count(var.get())); + TVM_FFI_ICHECK(defined_vars_.count(var.get())); auto output_idx = GetOutputIndex(var); if (output_idx) { return *output_idx; @@ -490,7 +490,7 @@ class FunctionCreator : public ExprMutator { // function. std::unordered_map> tuple_get_item_remap; for (auto& [tuple_arg, item_indices] : partially_used_tuple_params_) { - ICHECK(!item_indices.empty()); + TVM_FFI_ICHECK(!item_indices.empty()); int param_idx = tuple_param_idx_[tuple_arg]; Var param = params_[param_idx]; ffi::String param_name = params_[param_idx]->name_hint(); @@ -520,7 +520,7 @@ class FunctionCreator : public ExprMutator { if (const auto* tuple_get_item = var_binding->value.as()) { auto it = tuple_get_item_remap.find(tuple_get_item->tuple.get()); if (it != tuple_get_item_remap.end()) { - ICHECK(it->second.find(tuple_get_item->index) != it->second.end()); + TVM_FFI_ICHECK(it->second.find(tuple_get_item->index) != it->second.end()); var_remap_[var_binding->var->vid] = it->second[tuple_get_item->index]; if (auto output_idx = GetOutputIndex(binding->var)) { outputs.Set(*output_idx, it->second[tuple_get_item->index]); @@ -534,7 +534,7 @@ class FunctionCreator : public ExprMutator { // Case 1. It is an output binding // We only allow VarBinding as output. const auto* var_binding = binding.as(); - ICHECK_NOTNULL(var_binding); + TVM_FFI_ICHECK_NOTNULL(var_binding); Var output_var = builder_->EmitOutput(VisitExpr(var_binding->value)); var_remap_[var_binding->var->vid] = output_var; outputs.Set(*output_idx, output_var); @@ -748,8 +748,8 @@ class OperatorFusor : public ExprMutator { GroupMap obj2group; for (int nid = 0; nid < static_cast(graph.post_dfs_order.size()); ++nid) { Group* group_root = groups[nid]->FindRoot(); - ICHECK(group_root != nullptr); - ICHECK(graph.post_dfs_order[nid]->ref != nullptr); + TVM_FFI_ICHECK(group_root != nullptr); + TVM_FFI_ICHECK(graph.post_dfs_order[nid]->ref != nullptr); obj2group[graph.post_dfs_order[nid]->ref] = group_root; } return obj2group; @@ -757,7 +757,7 @@ class OperatorFusor : public ExprMutator { bool IsTupleOutput(Function f) { auto sinfo = GetStructInfo(f).as(); - ICHECK(sinfo); + TVM_FFI_ICHECK(sinfo); return sinfo->ret->IsInstance(); } @@ -813,7 +813,7 @@ class OperatorFusor : public ExprMutator { } const auto& it_creator = group2func_.find(group); - ICHECK(it_creator != group2func_.end()); + TVM_FFI_ICHECK(it_creator != group2func_.end()); const FunctionCreator& func_info = it_creator->second; if (!func_info.function_.defined()) { @@ -841,8 +841,9 @@ class OperatorFusor : public ExprMutator { // Case 3. The binding is the last binding of the group. const auto* var_binding = binding.as(); - ICHECK(var_binding != nullptr) << "The last binding of a group whose size is larger than 1 " - "is supposed to be a variable binding"; + TVM_FFI_ICHECK(var_binding != nullptr) + << "The last binding of a group whose size is larger than 1 " + "is supposed to be a variable binding"; // Step a. Add the grouped function to the IRModule GlobalVar gv = builder_->AddFunction(func, func_info.name_hint_); @@ -916,7 +917,7 @@ class OperatorFusor : public ExprMutator { // Skip the vars from input or groups with single binding. if (producer_group != cur_group) { for (Group* depgroup : group_deps_[producer_group]) { - ICHECK(depgroup != cur_group) + TVM_FFI_ICHECK(depgroup != cur_group) << "A cyclic dependency detected between the groups " << binding->var->name_hint() << " and " << used_var->name_hint() << " are in."; } @@ -935,7 +936,7 @@ class OperatorFusor : public ExprMutator { PostOrderVisit(var_binding->value, update_boundary); } else { const auto* match_cast = binding.as(); - ICHECK_NOTNULL(match_cast); + TVM_FFI_ICHECK_NOTNULL(match_cast); PostOrderVisit(match_cast->value, update_boundary); } } @@ -958,7 +959,7 @@ class OperatorFusor : public ExprMutator { */ Group* GetGroupFromVar(const Var& var) { const auto& it_group = obj2group_.find(var.get()); - ICHECK(it_group != obj2group_.end()) + TVM_FFI_ICHECK(it_group != obj2group_.end()) << "Variable " << var << " could not be found in any group"; Group* group = it_group->second; return group->FindRoot(); @@ -1152,7 +1153,7 @@ class PatternBasedPartitioner : ExprVisitor { // parent_group corresponds to the group of "conv1" above. auto parent_group = GetGroupForBoundVar(binding->var); - ICHECK(parent_group); + TVM_FFI_ICHECK(parent_group); parent_group->attrs.Set(attr::kComposite, pat_name_); if (attrs_getter_ != nullptr) { const auto& custom_attrs = attrs_getter_(context->annotated_expr); @@ -1186,7 +1187,7 @@ class PatternBasedPartitioner : ExprVisitor { } Group* GetGroupForBoundVar(const Var& bound_var) { - ICHECK(group_map_.count(bound_var.get())); + TVM_FFI_ICHECK(group_map_.count(bound_var.get())); return group_map_[bound_var.get()]->FindRoot(); } @@ -1354,7 +1355,8 @@ IRModule FuseOpsByPattern(const tvm::ffi::Array& patte for (const auto& name : entry_function_names) { auto gv = mod->GetGlobalVar(name); auto func = mod->Lookup(gv); - ICHECK(func->IsInstance()) << "Entry function must be a relax function"; + TVM_FFI_ICHECK(func->IsInstance()) + << "Entry function must be a relax function"; entry_functions.push_back(Downcast(func)); } } else { @@ -1378,8 +1380,7 @@ IRModule FuseOpsByPattern(const tvm::ffi::Array& patte pattern->name, pattern->pattern, pattern->annotation_patterns, pattern->check.value_or(nullptr), func, &arena, pattern->attrs_getter.value_or(nullptr)); for (const auto& [key, value] : map) { - CHECK(!group_map.count(key)) - << "ValueError: " + TVM_FFI_CHECK(!group_map.count(key), ValueError) << "IRModule is invalid. " << "The object " << ffi::GetRef(key) << " appears in multiple partitions, " << "which can occur when the IRModule was not single-site assignment"; diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 2c25b66b59c7..4a36047906c2 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -43,7 +43,7 @@ class SymbolicMatcher : ExprFunctor& params, const ffi::Array& args) { - CHECK_EQ(params.size(), args.size()); + TVM_FFI_ICHECK_EQ(params.size(), args.size()); for (size_t i = 0; i < params.size(); ++i) { Match(params[i], args[i]); } @@ -51,7 +51,7 @@ class SymbolicMatcher : ExprFunctorSimplify(Substitute(must_prove_, *var_remap_)); - CHECK(!is_zero(must_prove_)); + TVM_FFI_ICHECK(!is_zero(must_prove_)); } private: @@ -59,8 +59,9 @@ class SymbolicMatcher : ExprFunctor(); if (!rhs || (op->value != rhs->value)) { - LOG(FATAL) << "Parameter expression " << ffi::GetRef(op) - << " expected an integer argument with value " << op->value << ", " - << "but was provided with the argument " << other; + TVM_FFI_THROW(InternalError) + << "Parameter expression " << ffi::GetRef(op) + << " expected an integer argument with value " << op->value << ", " + << "but was provided with the argument " << other; } } void VisitExpr_(const FloatImmNode* op, const PrimExpr& other) { const auto* rhs = other.as(); if (!rhs || (op->value != rhs->value)) { - LOG(FATAL) << "Parameter expression " << ffi::GetRef(op) - << " expected an float argument with value " << op->value << ", " - << "but was provided with the argument " << other; + TVM_FFI_THROW(InternalError) << "Parameter expression " << ffi::GetRef(op) + << " expected an float argument with value " << op->value << ", " + << "but was provided with the argument " << other; } } void VisitExpr_(const CastNode* op, const PrimExpr& other) { const auto* rhs = other.as(); if (!rhs) { - LOG(FATAL) << "Parameter expression " << ffi::GetRef(op) << " expected an cast to " - << op->dtype << " as the argument, " - << "but was provided with the argument " << other; + TVM_FFI_THROW(InternalError) << "Parameter expression " << ffi::GetRef(op) + << " expected an cast to " << op->dtype << " as the argument, " + << "but was provided with the argument " << other; } VisitExpr(op->value, rhs->value); } @@ -129,9 +131,9 @@ class SymbolicMatcher : ExprFunctordtype.code() != rhs->dtype.code()) { - LOG(FATAL) << "Parameter expression " << ffi::GetRef(op) << " with dtype " - << op->dtype << " cannot match to argument " << rhs << " with dtype " - << rhs.dtype(); + TVM_FFI_THROW(InternalError) + << "Parameter expression " << ffi::GetRef(op) << " with dtype " << op->dtype + << " cannot match to argument " << rhs << " with dtype " << rhs.dtype(); } else if (auto it = var_remap_->find(lhs); it != var_remap_->end()) { VisitExpr((*it).second, rhs); } else { @@ -171,7 +173,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { Stmt Substitute(Stmt stmt) { return this->VisitStmt(std::move(stmt)); } Buffer SubstituteAllocatedBuffer(Buffer buffer) { - ICHECK(buffer_remap_.find(buffer) == buffer_remap_.end()); + TVM_FFI_ICHECK(buffer_remap_.find(buffer) == buffer_remap_.end()); ffi::Array shape = MutateArray(buffer->shape, [this](const PrimExpr& expr) { return this->VisitExpr(expr); }); ffi::Array strides = MutateArray( @@ -398,8 +400,8 @@ class SBlockNameDeduplicator : public tir::StmtMutator { return candidate; } ++counter; - ICHECK_GT(counter, 0) << "Counter overflow when generating unique block name for prefix: " - << prefix; + TVM_FFI_ICHECK_GT(counter, 0) + << "Counter overflow when generating unique block name for prefix: " << prefix; } } @@ -420,7 +422,8 @@ static ffi::Array GetInplaceOutputIndices(const ffi::Array& in if (i >= 0) { ret.push_back(Integer(i)); } else { - CHECK_EQ(i, -1) << "The only negative index expected in inplace_indices is -1, but got " << i; + TVM_FFI_ICHECK_EQ(i, -1) + << "The only negative index expected in inplace_indices is -1, but got " << i; ret.push_back(Integer(last_idx)); last_idx++; } @@ -448,7 +451,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { static const Op& call_tir_op_ = Op::Get("relax.call_tir"); static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); - ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_) + TVM_FFI_ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_) << "Only call_tir and call_tir_inplace are supported in primitive function, but got: " << ffi::GetRef(call); CollectVarMapping(call, current_var_, call->op == call_tir_inplace_op_); @@ -466,7 +469,8 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { if (lhs_var->IsInstance()) { relax_results = Downcast(lhs_var)->fields; } else { - CHECK(lhs_var->IsInstance()) << "The lhs_var is expected to be either tuple or var"; + TVM_FFI_ICHECK(lhs_var->IsInstance()) + << "The lhs_var is expected to be either tuple or var"; relax_results = {Downcast(lhs_var)}; } @@ -476,7 +480,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { ffi::Array output_idxs; if (in_place) { const auto* attrs = call->attrs.as(); - CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call"; + TVM_FFI_ICHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call"; output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, num_inputs); } else { for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) { @@ -488,7 +492,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { // structurally equal to the `new_buf` passed auto ValidateBufferCompatibility = [this](tir::Buffer new_buf, Expr expr) { if (auto it = relax_to_tir_var_map_.find(expr); it != relax_to_tir_var_map_.end()) { - ICHECK(StructuralEqual()((*it).second, new_buf)) + TVM_FFI_ICHECK(StructuralEqual()((*it).second, new_buf)) << "Inconsistent buffers " << (*it).second << " and " << new_buf << " mapped to the same relax var: " << expr; } @@ -531,9 +535,9 @@ class FusedTIRConstructor : public ExprVisitor { const GlobalVar& gv) { FusedTIRConstructor visitor(mod, gv->name_hint); BaseFunc f = mod->Lookup(gv); - CHECK(f->IsInstance()) + TVM_FFI_ICHECK(f->IsInstance()) << "Expected relax functions, but got: " << f->GetTypeKey(); - CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive)) + TVM_FFI_ICHECK(f->HasNonzeroAttr(relax::attr::kPrimitive)) << "Expected a function with attr `kPrimitive`"; visitor(Downcast(f)); ffi::Array inplace_indices; @@ -599,7 +603,7 @@ class FusedTIRConstructor : public ExprVisitor { // Step 3. Create and remap buffers for function output Expr body = func->body->body; auto it = func_info_.expr2buffers.find(body); - ICHECK(it != func_info_.expr2buffers.end()) + TVM_FFI_ICHECK(it != func_info_.expr2buffers.end()) << "Fail to detect output buffers for function body"; const ffi::Array& buffers = (*it).second; @@ -625,7 +629,7 @@ class FusedTIRConstructor : public ExprVisitor { // in duplicates in the buffer map otherwise) if (auto it = buffer_to_idx.find(buffers[i]); it != buffer_to_idx.end()) { auto idx = (*it).second; - CHECK(!inplace_indices_.count(idx)) + TVM_FFI_ICHECK(!inplace_indices_.count(idx)) << "In-place index " << idx << " used twice! An argument must be aliased."; inplace_indices_.insert(idx); continue; @@ -657,12 +661,12 @@ class FusedTIRConstructor : public ExprVisitor { // assign binding var to the buffers of the value func_info_.expr2buffers.Set(binding->var, (*it).second); } else { - LOG(FATAL) << "Unsupported binding value: " << binding->value; + TVM_FFI_THROW(InternalError) << "Unsupported binding value: " << binding->value; } } void VisitBinding_(const MatchCastNode* match_cast) final { - LOG(FATAL) << "MatchCast is unsupported in primitive functions"; + TVM_FFI_THROW(InternalError) << "MatchCast is unsupported in primitive functions"; } void VisitExpr_(const CallNode* call) final { @@ -670,7 +674,7 @@ class FusedTIRConstructor : public ExprVisitor { static const Op& call_tir_op_ = Op::Get("relax.call_tir"); static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); - ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_) + TVM_FFI_ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_) << "Only call_tir and call_tir_inplace are supported in primitive function, but got: " << ffi::GetRef(call); @@ -683,7 +687,7 @@ class FusedTIRConstructor : public ExprVisitor { // Step 3. Check functions are all schedulable funcs. i.e. the body of func is root block // TODO(Siyuan): support un-schedulable functions. - ICHECK(prim_func->body->IsInstance()) + TVM_FFI_ICHECK(prim_func->body->IsInstance()) << "Only schedulable functions (whose body is the root block) can be fused"; const tir::SBlockRealize& root_realize = Downcast(prim_func->body); const tir::SBlock& root_block = root_realize->block; @@ -702,18 +706,19 @@ class FusedTIRConstructor : public ExprVisitor { // Step 6. Update tir_vars if (call->args.size() > 2) { - ICHECK(call->args.size() == 3); + TVM_FFI_ICHECK(call->args.size() == 3); const Expr& tir_vars = call->args[2]; if (const auto* shape_expr = tir_vars.as()) { const auto& args = shape_expr->values; size_t num_params = prim_func->params.size(); - ICHECK_GE(num_params, args.size()); + TVM_FFI_ICHECK_GE(num_params, args.size()); for (size_t i = 0; i < args.size(); ++i) { const tir::Var& param = prim_func->params[num_params - args.size() + i]; func_info_.symbolic_var_matcher.Match(param, args[i]); } } else { - LOG(FATAL) << "TIR vars should be a shape expr, but got: " << tir_vars->GetTypeKey(); + TVM_FFI_THROW(InternalError) + << "TIR vars should be a shape expr, but got: " << tir_vars->GetTypeKey(); } } // Update fused func name @@ -753,7 +758,7 @@ class FusedTIRConstructor : public ExprVisitor { } void VisitExpr_(const ConstantNode* op) final { - LOG(FATAL) << "Relax.Constant is not supported in primitive functions."; + TVM_FFI_THROW(InternalError) << "Relax.Constant is not supported in primitive functions."; } /*! @@ -763,29 +768,33 @@ class FusedTIRConstructor : public ExprVisitor { static ffi::Array> GetCallTIROutputShapes(const CallNode* call) { static const Op& call_tir_op_ = Op::Get("relax.call_tir"); static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); - ICHECK(call->op.same_as(call_tir_op_) || call->op.same_as(call_tir_inplace_op_)); - ICHECK_EQ(call->sinfo_args.size(), 1); - auto get_tensor_shape = [](const TensorStructInfoNode* sinfo) { - const auto* shape_expr = sinfo->shape.as(); - CHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with symbolic shape."; - return shape_expr->values; - }; + TVM_FFI_ICHECK(call->op.same_as(call_tir_op_) || call->op.same_as(call_tir_inplace_op_)); + TVM_FFI_ICHECK_EQ(call->sinfo_args.size(), 1); + auto get_tensor_shape = + [](const TensorStructInfoNode* sinfo) { + const auto* shape_expr = sinfo->shape.as(); + TVM_FFI_ICHECK(shape_expr) + << "FuseTIR expects all parameters are Tensors with symbolic shape."; + return shape_expr->values; + }; if (const auto* tuple_sinfo = call->sinfo_args[0].as()) { ffi::Array> shapes; for (const StructInfo& field : tuple_sinfo->fields) { const auto* tensor_sinfo = field.as(); - CHECK(tensor_sinfo) << "CallTIR sinfo_args are expected to be TensorStructInfo or Tuple of " - "TensorStructInfo, but got " - << call->sinfo_args[0]; + TVM_FFI_ICHECK(tensor_sinfo) + << "CallTIR sinfo_args are expected to be TensorStructInfo or Tuple of " + "TensorStructInfo, but got " + << call->sinfo_args[0]; shapes.push_back(get_tensor_shape(tensor_sinfo)); } return shapes; } else if (const auto* tensor_sinfo = call->sinfo_args[0].as()) { return {get_tensor_shape(tensor_sinfo)}; } else { - CHECK(tensor_sinfo) << "CallTIR sinfo_args are expected to be TensorStructInfo or Tuple of " - "TensorStructInfo, but got " - << call->sinfo_args[0]; + TVM_FFI_ICHECK(tensor_sinfo) + << "CallTIR sinfo_args are expected to be TensorStructInfo or Tuple of " + "TensorStructInfo, but got " + << call->sinfo_args[0]; throw; } } @@ -799,7 +808,7 @@ class FusedTIRConstructor : public ExprVisitor { // Substitute the buffer with the already allocated one if it is an intermediate var if (it != func_info_.expr2buffers.end()) { for (const tir::Buffer& target_buffer : (*it).second) { - ICHECK_LT(buffer_idx, buffers.size()); + TVM_FFI_ICHECK_LT(buffer_idx, buffers.size()); const tir::Buffer& buffer = buffers[buffer_idx]; func_info_.symbolic_var_matcher.Match(buffer->shape, target_buffer->shape); func_info_.buffer_subst_map.Set(buffer, target_buffer); @@ -809,7 +818,7 @@ class FusedTIRConstructor : public ExprVisitor { } } // Make sure every buffer is mapped. - ICHECK_EQ(buffer_idx, buffers.size()); + TVM_FFI_ICHECK_EQ(buffer_idx, buffers.size()); } /*! @@ -826,7 +835,7 @@ class FusedTIRConstructor : public ExprVisitor { arg_list = {args}; } - ICHECK_GE(func->params.size(), arg_list.size()); + TVM_FFI_ICHECK_GE(func->params.size(), arg_list.size()); for (size_t i = 0; i < arg_list.size(); ++i) { const tir::Var& param = func->params[i]; const tir::Buffer& buffer = func->buffer_map.at(param); @@ -841,7 +850,7 @@ class FusedTIRConstructor : public ExprVisitor { size_t n = func->params.size(); int symbolic_var_index = -1; size_t output_size = output_indices.size(); - ICHECK_GE(n, output_size); + TVM_FFI_ICHECK_GE(n, output_size); ffi::Array ret; for (auto idx : output_indices) { @@ -850,17 +859,19 @@ class FusedTIRConstructor : public ExprVisitor { if (param->dtype.is_int() || param->dtype.is_uint()) { if (symbolic_var_index == -1) symbolic_var_index = i; } else if (param->dtype.is_handle()) { - CHECK(symbolic_var_index == -1) << "The scalar input should be at the ending of the " - "parameter list."; + TVM_FFI_ICHECK(symbolic_var_index == -1) + << "The scalar input should be at the ending of the " + "parameter list."; ret.push_back(param); } else { - LOG(FATAL) << "The params of PrimFunc are expected to be Buffer handle or scalar, but got: " - << param->dtype; + TVM_FFI_THROW(InternalError) + << "The params of PrimFunc are expected to be Buffer handle or scalar, but got: " + << param->dtype; } } size_t end_index = symbolic_var_index == -1 ? n : symbolic_var_index; - ICHECK_GE(end_index, output_size); + TVM_FFI_ICHECK_GE(end_index, output_size); return ret; } @@ -878,12 +889,12 @@ class FusedTIRConstructor : public ExprVisitor { size_t n = func->params.size(); int num_inputs = Downcast(call->args[1])->fields.size(); size_t output_size = output_shapes.size(); - ICHECK_GE(n, output_size); + TVM_FFI_ICHECK_GE(n, output_size); ffi::Array output_buffers; ffi::Array output_idxs; if (is_inplace) { const auto* attrs = call->attrs.as(); - CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call"; + TVM_FFI_ICHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call"; output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, num_inputs); } else { for (size_t i = 0; i < output_size; i++) { @@ -899,7 +910,8 @@ class FusedTIRConstructor : public ExprVisitor { // if this is an inplace output, do not do an intermediate allocation if (output_idxs[i].IntValue() < num_inputs) { - CHECK(input_buffers.has_value()) << "Inplace functions must have some defined input"; + TVM_FFI_ICHECK(input_buffers.has_value()) + << "Inplace functions must have some defined input"; output_buffers.push_back(input_buffers.value()[output_idxs[i].IntValue()]); continue; } @@ -946,8 +958,7 @@ class FusedTIRConstructor : public ExprVisitor { const ffi::Optional& tir_buffer_param) { auto struct_info = GetStructInfo(relax_param); - CHECK(!struct_info.as()) - << "InternalError: " + TVM_FFI_CHECK(!struct_info.as(), InternalError) << "All tuple parameters should be expanded before this point in FuseTIR. " << "However, parameter " << relax_param << " has struct info " << struct_info; @@ -956,7 +967,7 @@ class FusedTIRConstructor : public ExprVisitor { if (const auto* tensor = struct_info.as()) { // Case 1. The relax param is a Tensor, we directly create a tir var and buffer const auto* shape_expr = tensor->shape.as(); - ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape."; + TVM_FFI_ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape."; DataType dtype = tensor->dtype; tir::Buffer buffer; if (tir_buffer_param.defined()) { @@ -970,20 +981,19 @@ class FusedTIRConstructor : public ExprVisitor { } else if (const auto* prim_value = struct_info.as()) { // Case 2. The relax param is a scalar, we directly create a tir var - ICHECK(prim_value->value->IsInstance()); + TVM_FFI_ICHECK(prim_value->value->IsInstance()); out->push_back(Downcast(prim_value->value)); } else if (const auto* shape_expr = struct_info.as()) { // Case 3. The relax param is a tuple of scalars, each represented as a tir var for (const auto& var : shape_expr->values.value()) { - ICHECK(var->IsInstance()); + TVM_FFI_ICHECK(var->IsInstance()); out->push_back(Downcast(var)); } } else { - LOG(FATAL) << "TypeError: " - << "The param type of PrimFunc is expected to be " - << "Tensor, PrimValue, or ShapeExpr, " - << "but got " << struct_info->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "The param type of PrimFunc is expected to be " + << "Tensor, PrimValue, or ShapeExpr, " + << "but got " << struct_info->GetTypeKey(); } } @@ -995,7 +1005,7 @@ class FusedTIRConstructor : public ExprVisitor { ffi::Map attr_map; attr_map.Set(tir::attr::kNoAlias, true); tir::FuseTIRBufferSubstitutor subst(func_info_.buffer_subst_map, func_info_.symbolic_var_remap); - ICHECK(func_info_.global_name != "fused"); + TVM_FFI_ICHECK(func_info_.global_name != "fused"); // Remove output buffers from func_info_.alloc_buffers ffi::Array alloc_buffers; for (const tir::Buffer& buf : func_info_.alloc_buffers) { @@ -1025,7 +1035,7 @@ class FusedTIRConstructor : public ExprVisitor { } return num; } else { - LOG(FATAL) << "TensorType and TupleType are expect, but got: " << sinfo; + TVM_FFI_THROW(InternalError) << "TensorType and TupleType are expect, but got: " << sinfo; return 0; } } @@ -1153,7 +1163,7 @@ class TIRFuseMutator : public ExprMutator { for (const auto& [gv, func] : mod->functions) { if (func->IsInstance()) { - ICHECK(!func->HasNonzeroAttr(attr::kPrimitive)) + TVM_FFI_ICHECK(!func->HasNonzeroAttr(attr::kPrimitive)) << "Module should not contain any primitive relax functions at this point"; relax::Function update_func = Downcast(mutator.VisitExpr(func)); if (!update_func.same_as(func)) { @@ -1186,9 +1196,9 @@ class TIRFuseMutator : public ExprMutator { return Tuple(fields); } else { auto* tensor = sinfo.as(); - ICHECK(tensor) << "FuseTIR can only take tensor or tuple type"; + TVM_FFI_ICHECK(tensor) << "FuseTIR can only take tensor or tuple type"; auto* shape_expr = tensor->shape.as(); - ICHECK(shape_expr) << "FuseTIR requires all intermediate values have shape"; + TVM_FFI_ICHECK(shape_expr) << "FuseTIR requires all intermediate values have shape"; return ffi::GetRef(shape_expr); } } @@ -1228,9 +1238,9 @@ class TIRFuseMutator : public ExprMutator { auto arg = call->args[i]; auto sinfo = GetStructInfo(arg); - ICHECK(!relax_func->params[i]->struct_info_->IsInstance() && - !sinfo.as()) - << "InternalError: " + TVM_FFI_CHECK(!relax_func->params[i]->struct_info_->IsInstance() && + !sinfo.as(), + InternalError) << "All tuple parameters should be expanded before this point in FuseTIR. " << "However, argument " << arg << " with struct info " << arg->struct_info_ << " is passed as argument " << i << " to Primitive Relax function " << old_gvar @@ -1238,18 +1248,20 @@ class TIRFuseMutator : public ExprMutator { << relax_func->params[i]->struct_info_; if (const auto* shape = sinfo.as()) { - CHECK(shape->values.defined()) << "FuseTIR requires all shape input has struct_info value."; + TVM_FFI_ICHECK(shape->values.defined()) + << "FuseTIR requires all shape input has struct_info value."; for (const PrimExpr& prim_value : shape->values.value()) { - CHECK(prim_value->IsInstance()) + TVM_FFI_ICHECK(prim_value->IsInstance()) << "All shape inputs are expected to be single tir var."; tir_vars.push_back(prim_value); } } else if (const auto* prim_value = sinfo.as()) { - CHECK(prim_value->value.defined()) + TVM_FFI_ICHECK(prim_value->value.defined()) << "FuseTIR requires all R.Prim arguments to have a known value."; PrimExpr expr = prim_value->value.value(); - CHECK(expr->IsInstance()) << "FuseTIR currently requires all R.Prim " - "arguments to provide a single tir::Var."; + TVM_FFI_ICHECK(expr->IsInstance()) + << "FuseTIR currently requires all R.Prim " + "arguments to provide a single tir::Var."; tir_vars.push_back(expr); } else { diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index 15bf6a273a3f..90e0619f5dd5 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -129,7 +129,7 @@ class CheckpointCollector : private ExprMutator { // 1) the output of end_checkpoint; 2) checkpointed // then the variable of binding will be checkpointed auto var_binding = binding.as(); - ICHECK(var_binding); + TVM_FFI_ICHECK(var_binding); auto value_call = var_binding->value.as(); if (!value_call || (value_call->op != s_cp && value_call->op != e_cp)) { @@ -157,8 +157,8 @@ class CheckpointCollector : private ExprMutator { if (value->op == s_cp || value->op == e_cp) { // Eliminate the binding auto var = value->args[0].as(); - ICHECK(var) << "The first argument of relax.grad.start_checkpoint and " - "relax.grad.end_checkpoint should be a Var"; + TVM_FFI_ICHECK(var) << "The first argument of relax.grad.start_checkpoint and " + "relax.grad.end_checkpoint should be a Var"; // var might already be remapped. Find the original var auto orig_var = Downcast(ExprMutator::VisitExpr(ffi::GetRef(var))); // Add remapping from binding->var to new_var @@ -213,7 +213,7 @@ class CheckpointGenerator : private ExprMutator { for (auto binding : forward_block->bindings) { auto* var_binding = binding.as(); - CHECK(var_binding) << "Now only support VarBindingNode"; + TVM_FFI_ICHECK(var_binding) << "Now only support VarBindingNode"; auto var = var_binding->var; binding_map_.Set(var, var_binding->value); if (checkpoints.count(var->vid)) { @@ -323,7 +323,7 @@ class BackwardBindingGenerator : private ExprVisitor { void VisitBinding(const Binding& binding) final { // TODO(chaofan, yixin): support other types of bindings - CHECK(binding->IsInstance()) << "Now only support VarBindingNode"; + TVM_FFI_ICHECK(binding->IsInstance()) << "Now only support VarBindingNode"; auto* var_binding = binding.as(); if (adjoint_var_map_.count(var_binding->var) == 0) { @@ -333,9 +333,9 @@ class BackwardBindingGenerator : private ExprVisitor { Expr value = var_binding->value; // TODO(chaofan, yixin): support other types of binding values - CHECK(value->IsInstance() || value->IsInstance() || - value->IsInstance() || value->IsInstance() || - value->IsInstance()) + TVM_FFI_ICHECK(value->IsInstance() || value->IsInstance() || + value->IsInstance() || value->IsInstance() || + value->IsInstance()) << "Now does not support the type of binding value: " << value; ExprVisitor::VisitBinding_(var_binding); @@ -361,8 +361,9 @@ class BackwardBindingGenerator : private ExprVisitor { checkpoint_generator_.UpdateBinding(binding->var, ffi::GetRef(call)); if (call_op == Op::Get("relax.call_tir")) { - LOG(FATAL) << "Differentiation of call_tir op without registering corresponding gradient " - "function is not supported yet."; + TVM_FFI_THROW(InternalError) + << "Differentiation of call_tir op without registering corresponding gradient " + "function is not supported yet."; } else if (call_op == Op::Get("relax.call_tir_with_grad")) { // tir gradient registering auto te_grad_name = call->attrs.as()->te_grad_name; @@ -375,10 +376,10 @@ class BackwardBindingGenerator : private ExprVisitor { auto* tuple_sinfo = GetStructInfoAs(partials); if (!tuple_sinfo) { // result_var is a tensor - ICHECK(args->fields.size() == 1); + TVM_FFI_ICHECK(args->fields.size() == 1); UpdateAdjoint(args->fields[0], partials); } else { - ICHECK(args->fields.size() == tuple_sinfo->fields.size()); + TVM_FFI_ICHECK(args->fields.size() == tuple_sinfo->fields.size()); for (int i = 0; i < static_cast(args->fields.size()); ++i) { UpdateAdjoint(args->fields[i], TupleGetItem(partials, i)); } @@ -386,7 +387,7 @@ class BackwardBindingGenerator : private ExprVisitor { } else { const ffi::Array& partials = gradient_op_map[call_op]( checkpoint_var, Downcast(checkpoint_call), adjoint_var, builder_); - ICHECK(partials.size() == call->args.size()) << "partials number != inputs number"; + TVM_FFI_ICHECK(partials.size() == call->args.size()) << "partials number != inputs number"; for (size_t i = 0; i < partials.size(); ++i) { Expr partial = partials[i]; if (IsCallNoGrad(partial)) { // no grad: don't update @@ -415,10 +416,10 @@ class BackwardBindingGenerator : private ExprVisitor { // a_adjoint[0] += b_adjoint_var // If a_adjoint does not exist, we would create a zeros tuple as a_adjoint first, and then add void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* tuple_get_item) final { - ICHECK(tuple_get_item->tuple->IsInstance()) + TVM_FFI_ICHECK(tuple_get_item->tuple->IsInstance()) << "The tuple field of a TupleGetItem is not bound to a Var"; auto* tuple_sinfo = GetStructInfoAs(tuple_get_item->tuple); - ICHECK(tuple_sinfo) << "The tuple field of a TupleGetItem must has a TupleStructInfo"; + TVM_FFI_ICHECK(tuple_sinfo) << "The tuple field of a TupleGetItem must has a TupleStructInfo"; const Var& tuple_var = Downcast(tuple_get_item->tuple); if (adjoint_var_map_.count(tuple_var) == 0) { @@ -467,10 +468,11 @@ class BackwardBindingGenerator : private ExprVisitor { // nothing to do } else if (leaf->IsInstance()) { // must be no grad - ICHECK(IsCallNoGrad(partial)); + TVM_FFI_ICHECK(IsCallNoGrad(partial)); } else { - LOG(FATAL) << "UpdateAdjoint: leaf type not supported. Currently Var and Constant leaves " - "are supported."; + TVM_FFI_THROW(InternalError) + << "UpdateAdjoint: leaf type not supported. Currently Var and Constant leaves " + "are supported."; } }); } @@ -522,7 +524,7 @@ class BackwardBindingGenerator : private ExprVisitor { static Expr AdjointMsgToExpr(AdjointMsg msg) { return NestedMsgToExpr(msg, [](ffi::Optional leaf_expr) { if (!leaf_expr.defined()) { - LOG(FATAL) << "Null should not exist in AdjointMsg."; + TVM_FFI_THROW(InternalError) << "Null should not exist in AdjointMsg."; } return leaf_expr.value(); }); @@ -530,7 +532,7 @@ class BackwardBindingGenerator : private ExprVisitor { static AdjointMsg ExprToAdjointMsg(Expr expr) { return MapToNestedMsgBySInfo(expr, [](Expr leaf) { - ICHECK(GetStructInfoAs(leaf)) + TVM_FFI_ICHECK(GetStructInfoAs(leaf)) << "The leaf of adjoint: " << leaf << " should have StructInfo and be a Tensor."; return AdjointMsg(leaf); }); @@ -541,8 +543,8 @@ class BackwardBindingGenerator : private ExprVisitor { static Expr NestedZeros(const StructInfo& sinfo) { AdjointMsg msg = MapToNestedMsg(sinfo, [](StructInfo sinfo) { auto* tensor_sinfo = sinfo.as(); - ICHECK(tensor_sinfo) << "The leaf of adjoint should be a Tensor."; - ICHECK(tensor_sinfo->shape.defined()) << "Missing shape when building zeros tuple."; + TVM_FFI_ICHECK(tensor_sinfo) << "The leaf of adjoint should be a Tensor."; + TVM_FFI_ICHECK(tensor_sinfo->shape.defined()) << "Missing shape when building zeros tuple."; const Expr& init = zeros(tensor_sinfo->shape.value(), tensor_sinfo->dtype); return init; }); @@ -555,8 +557,8 @@ class BackwardBindingGenerator : private ExprVisitor { AdjointMsg res = CombineNestedMsg( ExprToAdjointMsg(lhs), ExprToAdjointMsg(rhs), [](Expr l_leaf, Expr r_leaf) { auto* sinfo = GetStructInfoAs(l_leaf); - ICHECK(sinfo) << "The leaf of adjoint should have StructInfo and be a Tensor."; - ICHECK(GetStructInfoAs(r_leaf)) + TVM_FFI_ICHECK(sinfo) << "The leaf of adjoint should have StructInfo and be a Tensor."; + TVM_FFI_ICHECK(GetStructInfoAs(r_leaf)) << "The leaf of adjoint should have StructInfo and be a Tensor."; Expr res = add(l_leaf, r_leaf); UpdateStructInfo(res, ffi::GetRef(sinfo)); @@ -573,8 +575,8 @@ class BackwardBindingGenerator : private ExprVisitor { // Step 3) tuple_new = (t1, t2_new, t3) static Expr AddInTuple(const Expr& tuple, int index, const Expr& increment) { auto* sinfo = GetStructInfoAs(tuple); - ICHECK(sinfo) << "The first argument of AddInTuple should have tuple struct info."; - ICHECK(index >= 0 && index < static_cast(sinfo->fields.size())); + TVM_FFI_ICHECK(sinfo) << "The first argument of AddInTuple should have tuple struct info."; + TVM_FFI_ICHECK(index >= 0 && index < static_cast(sinfo->fields.size())); ffi::Array res; for (size_t i = 0; i < sinfo->fields.size(); ++i) { Expr field; @@ -607,7 +609,7 @@ class GradientMutator : private ExprMutator { ffi::Optional> require_grads, int target_index) { // Step 1. Copy function auto* old_func = mod->Lookup(func_name).as(); - CHECK(old_func) << func_name << "is not a Relax Function"; + TVM_FFI_ICHECK(old_func) << func_name << "is not a Relax Function"; auto copier = FunctionCopier(); auto new_func = copier.Copy(ffi::GetRef(old_func)); @@ -674,9 +676,9 @@ class GradientMutator : private ExprMutator { Expr VisitExpr_(const SeqExprNode* seq_expr) final { // TODO(chaofan, yixin): multiple blocks AD - CHECK(seq_expr->blocks.size() == 1) << "now only support one dataflow block"; + TVM_FFI_ICHECK(seq_expr->blocks.size() == 1) << "now only support one dataflow block"; // TODO(chaofan, yixin): AD in non-dataflow block. - CHECK(seq_expr->blocks[0]->IsInstance()) + TVM_FFI_ICHECK(seq_expr->blocks[0]->IsInstance()) << "now only support one dataflow block"; // the return value should be a VarNode, and a scalar @@ -712,27 +714,29 @@ class GradientMutator : private ExprMutator { // Check that the target should be a Var of scalar tensor struct_info void CheckAndSetTarget(const Expr& e, int target_index) { if (auto* var = e.as()) { - CHECK_EQ(target_index, 0) << "When the function has only one return value, target_index can " - "only be 0. But the target_index specified is " - << target_index; + TVM_FFI_ICHECK_EQ(target_index, 0) + << "When the function has only one return value, target_index can " + "only be 0. But the target_index specified is " + << target_index; target_var_ = ffi::GetRef(var); } else if (auto* tuple = e.as()) { - CHECK(target_index >= 0 && target_index < static_cast(tuple->fields.size())) + TVM_FFI_ICHECK(target_index >= 0 && target_index < static_cast(tuple->fields.size())) << "target_index should be in the range of the number of return values of the " "function. " "But the specified target_index is " << target_index << ", while the number of return values is " << tuple->fields.size(); auto* var = tuple->fields[target_index].as(); - CHECK(var) << "Target must be a Var, but the specified target is " - << tuple->fields[target_index]; + TVM_FFI_ICHECK(var) << "Target must be a Var, but the specified target is " + << tuple->fields[target_index]; target_var_ = ffi::GetRef(var); } else { - LOG(FATAL) << "The return value of the function must be Var or Tuple. However, the return " - "value of the given function is " - << e; + TVM_FFI_THROW(InternalError) + << "The return value of the function must be Var or Tuple. However, the return " + "value of the given function is " + << e; } auto target_sinfo = GetStructInfo(target_var_); - CHECK(IsScalarTensor(target_sinfo) && IsFloatTensorSInfo(target_sinfo)) + TVM_FFI_ICHECK(IsScalarTensor(target_sinfo) && IsFloatTensorSInfo(target_sinfo)) << "The differentiation target must be a float scalar (0-dim Tensor), but the StructInfo " "of the given target " << target_var_ << " is " << GetStructInfo(target_var_); @@ -749,14 +753,14 @@ class GradientMutator : private ExprMutator { ffi::Array mapped_vars; for (const auto& var : require_grads) { auto it = var_map.find(var); - CHECK(it != var_map.end()) << "There is no Var named " << var->name_hint() - << " in the function " << func_name; - CHECK_EQ(var_set.count(var->vid), 0) + TVM_FFI_ICHECK(it != var_map.end()) + << "There is no Var named " << var->name_hint() << " in the function " << func_name; + TVM_FFI_ICHECK_EQ(var_set.count(var->vid), 0) << "Var " << var->name_hint() << " appears more than once"; var_set.emplace(var->vid); mapped_vars.push_back((*it).second); - CHECK(IsNestedTensorConditioned(GetStructInfo(var), IsFloatTensorSInfo)) + TVM_FFI_ICHECK(IsNestedTensorConditioned(GetStructInfo(var), IsFloatTensorSInfo)) << "Only Tensors of floating point dtype or Tuples of float " "Tensors can require gradients, but the StructInfo of Var " << var->name_hint() << " is " << GetStructInfo(var); diff --git a/src/relax/transform/gradient_simplifier.cc b/src/relax/transform/gradient_simplifier.cc index 5388e3706542..32eb6707dfe2 100644 --- a/src/relax/transform/gradient_simplifier.cc +++ b/src/relax/transform/gradient_simplifier.cc @@ -95,7 +95,7 @@ class GradientSimplifier : private ExprMutator { return ndim == 2; } auto axes = call_node->attrs.as()->axes.value(); - ICHECK(static_cast(axes.size()) == ndim); + TVM_FFI_ICHECK(static_cast(axes.size()) == ndim); for (int i = 0; i < ndim - 2; ++i) { if (axes[i] != i) { return false; @@ -107,7 +107,7 @@ class GradientSimplifier : private ExprMutator { // Return permute_dims(expr). Generate the axes needed. static Expr GetTransposeOf(const Expr& expr) { auto sinfo = MatchStructInfo(expr); - ICHECK(sinfo); + TVM_FFI_ICHECK(sinfo); auto ndim = sinfo.value()->ndim; if (ndim == 1) { return expr; diff --git a/src/relax/transform/infer_amp_utils.cc b/src/relax/transform/infer_amp_utils.cc index ac838d584821..cf64837d5bf4 100644 --- a/src/relax/transform/infer_amp_utils.cc +++ b/src/relax/transform/infer_amp_utils.cc @@ -25,7 +25,7 @@ namespace relax { NType NTypeFrom(const StructInfo& sinfo, DataType dtype) { auto fmapleaf = [&](const StructInfo& sinfo) -> NType { const auto* tensor = sinfo.as(); - ICHECK(tensor) << "Expected TensorStructInfo, but got " << sinfo; + TVM_FFI_ICHECK(tensor) << "Expected TensorStructInfo, but got " << sinfo; if (dtype == DataType::Void()) return NType(DLDataTypeToString(tensor->dtype)); else @@ -46,8 +46,8 @@ NType NTypeMerge(const NType& a, const NType& b) { DataType a = DataType(ffi::StringToDLDataType(a_str)); DataType b = DataType(ffi::StringToDLDataType(b_str)); - ICHECK_EQ(a.code(), b.code()); - ICHECK_EQ(a.lanes(), b.lanes()); + TVM_FFI_ICHECK_EQ(a.code(), b.code()); + TVM_FFI_ICHECK_EQ(a.lanes(), b.lanes()); return a.bits() > b.bits() ? a_str : b_str; }; return CombineNestedMsg(a, b, fcombine); diff --git a/src/relax/transform/infer_layout_utils.cc b/src/relax/transform/infer_layout_utils.cc index bc572f8a5407..b33beaa4e513 100644 --- a/src/relax/transform/infer_layout_utils.cc +++ b/src/relax/transform/infer_layout_utils.cc @@ -34,15 +34,17 @@ std::string TransposeSubLayoutStrLike(const std::string ref_str, const std::stri for (const char& c : desired_str) { if (std::isupper(c)) { auto res = src_str.find(c, 0); - ICHECK(res != std::string::npos) << "Invalid Layout:" - << "can't find " << c << " in source layout" << src_str; + TVM_FFI_ICHECK(res != std::string::npos) + << "Invalid Layout:" + << "can't find " << c << " in source layout" << src_str; out.push_back(ref_str[res]); } else if (isdigit(c)) { out.push_back(c); } else if (std::islower(c)) { auto res = src_str.find(std::toupper(c), 0); - ICHECK(res != std::string::npos) << "Invalid Layout:" - << "can't find " << c << " in source layout" << src_str; + TVM_FFI_ICHECK(res != std::string::npos) + << "Invalid Layout:" + << "can't find " << c << " in source layout" << src_str; out.push_back(std::tolower(ref_str[res])); } } @@ -58,7 +60,7 @@ Layout TransposeSubLayoutLike(const Layout& ref, const Layout& src, const Layout } Layout TransposeLike(const Layout& input, const Layout& src, const Layout& dst) { - ICHECK(src.ndim() == dst.ndim() && input.ndim() == src.ndim()) + TVM_FFI_ICHECK(src.ndim() == dst.ndim() && input.ndim() == src.ndim()) << "Layouts must have the same size"; std::vector axes; for (size_t i = 0; i < src.ndim(); ++i) { @@ -68,7 +70,7 @@ Layout TransposeLike(const Layout& input, const Layout& src, const Layout& dst) } ffi::String TransposeStrLike(const ffi::String& input, const Layout& src, const Layout& dst) { - ICHECK(src.ndim() == dst.ndim() && input.size() == src.ndim()) + TVM_FFI_ICHECK(src.ndim() == dst.ndim() && input.size() == src.ndim()) << "Layouts must have the same size"; std::string axes; for (size_t i = 0; i < src.ndim(); ++i) { @@ -87,7 +89,7 @@ int FindAxis(const Layout& dst, int axis) { } Layout InitialLayout(int ndim) { - ICHECK(ndim >= 0 && ndim <= 26) << "Only support up to 26 dimensions, but got " << ndim; + TVM_FFI_ICHECK(ndim >= 0 && ndim <= 26) << "Only support up to 26 dimensions, but got " << ndim; return Layout("ABCDEFGHIJKLMNOPQRSTUVWXYZ").SubLayout(0, ndim); } @@ -95,7 +97,7 @@ LayoutDecision InitialLayoutDecision(int ndim) { if (ndim == kUnknownNDim) { return LayoutDecision::InitUnknownDim(); } - ICHECK(ndim >= 0 && ndim <= 26) << "Only support up to 26 dimensions, but got " << ndim; + TVM_FFI_ICHECK(ndim >= 0 && ndim <= 26) << "Only support up to 26 dimensions, but got " << ndim; return Layout("ABCDEFGHIJKLMNOPQRSTUVWXYZ").SubLayout(0, ndim); } @@ -113,7 +115,7 @@ NLayout InitialNLayout(const Expr& expr) { return InitialNLayout(GetStructInfo(e LayoutDecision GetLayoutDecision(const VarLayoutMap& var_layout_map, const Expr& arg) { NLayout nlayout = GetNLayout(var_layout_map, arg); - ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << arg; + TVM_FFI_ICHECK(nlayout.IsLeaf()) << "Cannot get layout for " << arg; return nlayout.LeafValue(); } @@ -148,7 +150,8 @@ LayoutDecision FollowDecision(const LayoutDecision& src, int dst_ndim) { if (src_ndim == dst_ndim) { return src; } else { - ICHECK_LT(src_ndim, dst_ndim) << "Cannot broadcast from " << src_ndim << " to " << dst_ndim; + TVM_FFI_ICHECK_LT(src_ndim, dst_ndim) + << "Cannot broadcast from " << src_ndim << " to " << dst_ndim; std::string layout = InitialLayout(dst_ndim - src_ndim).name(); for (int i = 0; i < src_ndim; ++i) { layout.push_back(src->layout.name()[i] + dst_ndim - src_ndim); diff --git a/src/relax/transform/inline_functions.cc b/src/relax/transform/inline_functions.cc index f3f21cc7843d..ec45ba00f1b3 100644 --- a/src/relax/transform/inline_functions.cc +++ b/src/relax/transform/inline_functions.cc @@ -57,7 +57,7 @@ class FunctionInliner : public ExprMutator { auto gvar = opt.value(); if (auto opt = GetFunction(gvar)) { auto callee = opt.value(); - CHECK_EQ(callee->params.size(), node->args.size()) + TVM_FFI_ICHECK_EQ(callee->params.size(), node->args.size()) << "Attempted to inline call to " << gvar << ", which accepts " << callee->params.size() << " parameters. " << "However, it was called with " << node->args.size() << " arguments in expression " @@ -65,7 +65,7 @@ class FunctionInliner : public ExprMutator { Expr inlined = InlinedCall(callee, node->args); - CHECK(!inline_stack_.count(gvar)) + TVM_FFI_ICHECK(!inline_stack_.count(gvar)) << "Relax function inlining does not support recursive functions. " << "However, recursive function " << gvar << " was requested to be inlined."; @@ -154,8 +154,7 @@ Function FunctionInlineFunctions( Function func, const ffi::Map, Function>& replacements) { for (const auto& [key, func] : replacements) { if (auto ptr = key.as()) { - CHECK(!replacements.count(ptr->name_hint)) - << "ValueError: " + TVM_FFI_CHECK(!replacements.count(ptr->name_hint), ValueError) << "Map of functions to inline must be unambiguous. " << "However, the map provided contains both the GlobalVar " << key << " and the string \'" << ptr->name_hint << "'"; diff --git a/src/relax/transform/kill_after_last_use.cc b/src/relax/transform/kill_after_last_use.cc index e1e8a5d87998..bae9794ecc22 100644 --- a/src/relax/transform/kill_after_last_use.cc +++ b/src/relax/transform/kill_after_last_use.cc @@ -165,12 +165,12 @@ class CollectLastUsage : public ExprVisitor { storage_objects_.insert(binding->var.get()); } else if (val->op.same_as(mem_kill_tensor) || val->op.same_as(mem_kill_storage) || val->op.same_as(vm_kill_object)) { - CHECK_EQ(val->args.size(), 1) + TVM_FFI_ICHECK_EQ(val->args.size(), 1) << "Operator " << val->op << " should have one argument, " << "but instead found " << val->args.size() << " arguments: " << val->args; auto killed_object = val->args[0].as(); - ICHECK(killed_object) << "Internal error: non-normalized expression " - << ffi::GetRef(val); + TVM_FFI_ICHECK(killed_object) + << "Internal error: non-normalized expression " << ffi::GetRef(val); killed_objects_.insert(killed_object); } else { // Only recursively visit if it isn't one of the special cases. diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc index e77b0a266038..1090669ea31b 100644 --- a/src/relax/transform/lambda_lift.cc +++ b/src/relax/transform/lambda_lift.cc @@ -70,7 +70,7 @@ class LambdaNameCollector : ExprVisitor { // model definition, they are intentionally verbose to // (hopefully) provide sufficient context to a user encountering // the error. - CHECK(!previous_global_vars_.count(public_name)) + TVM_FFI_ICHECK(!previous_global_vars_.count(public_name)) << "Function " << name_stack_.front() << " contains a lambda with kGlobalSymbol (\"" << tvm::attr::kGlobalSymbol << "\" attribute of \"" << public_name << "\". " << "However, the module already contains a GlobalVar with this name. " @@ -80,7 +80,7 @@ class LambdaNameCollector : ExprVisitor { << " would require violating one of these two conditions."; auto it = new_public_names_.find(public_name); - CHECK(it == new_public_names_.end()) + TVM_FFI_ICHECK(it == new_public_names_.end()) << "Function " << name_stack_.front() << " contains a lambda with kGlobalSymbol (\"" << tvm::attr::kGlobalSymbol << "\" attribute of \"" << public_name << "\". " << "However, the function " << it->second.front() @@ -213,7 +213,7 @@ class LambdaNameCollector : ExprVisitor { return stream.str(); }); - ICHECK(remaining_to_name.empty()) + TVM_FFI_ICHECK(remaining_to_name.empty()) << "Fallback failed to make unique names for all lifted lambda functions"; return lifted_names; @@ -263,8 +263,7 @@ class LambdaLifter : public ExprMutator { ffi::String lift_func_name = [&]() { auto it = lifted_names_.find(func_node); - ICHECK(it != lifted_names_.end()) - << "InternalError: " + TVM_FFI_CHECK(it != lifted_names_.end(), InternalError) << "Found lambda function during mutation step, " << "but it wasn't found during the earlier name-generation step."; return it->second; @@ -333,7 +332,7 @@ class LambdaLifter : public ExprMutator { Function(lifted_func_params, body, ret_struct_info, func_node->is_pure, func_node->attrs); } - ICHECK(lifted_func.defined()); + TVM_FFI_ICHECK(lifted_func.defined()); if (is_closure || IsClosure(lifted_func)) { closures_.insert(gvar_lifted_func); @@ -378,11 +377,12 @@ class LambdaLifter : public ExprMutator { orig_call->op->struct_info_.as()) { return func_sinfo->purity; } else { - LOG(FATAL) << "Could not determine purity of call to " << orig_call->op - << ", as it is neither a tvm::Op (type = \"" << orig_call->op->GetTypeKey() - << "\"), " - << "nor is is annotated with FuncStructInfo (sinfo = " - << orig_call->op->struct_info_ << ")"; + TVM_FFI_THROW(InternalError) + << "Could not determine purity of call to " << orig_call->op + << ", as it is neither a tvm::Op (type = \"" << orig_call->op->GetTypeKey() + << "\"), " + << "nor is is annotated with FuncStructInfo (sinfo = " + << orig_call->op->struct_info_ << ")"; } }(); @@ -444,7 +444,7 @@ class LambdaLifter : public ExprMutator { return true; } IRModule ctx_mod = builder_->GetContextIRModule(); - ICHECK(ctx_mod->functions.size() > 0); + TVM_FFI_ICHECK(ctx_mod->functions.size() > 0); BaseFunc func = ctx_mod->Lookup(ffi::GetRef(global_var)); const auto* func_node = func.as(); if (func_node) { diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc index bc6f4530db59..528978da5406 100644 --- a/src/relax/transform/lazy_transform_params.cc +++ b/src/relax/transform/lazy_transform_params.cc @@ -37,11 +37,10 @@ namespace { std::optional GetNumInputParams(const FunctionNode* func) { if (auto opt_int_imm = func->GetAttr(attr::kNumInput)) { int64_t num_input_params = opt_int_imm.value()->value; - CHECK_GE(num_input_params, 0) << "ValueError: " - << "Annotation for attr::kNumInput (\"" << attr::kNumInput - << "\") must be non-negative, but was " << num_input_params; - CHECK_LE(static_cast(num_input_params), func->params.size()) - << "ValueError: " + TVM_FFI_CHECK_GE(num_input_params, 0, ValueError) + << "Annotation for attr::kNumInput (\"" << attr::kNumInput + << "\") must be non-negative, but was " << num_input_params; + TVM_FFI_CHECK_LE(static_cast(num_input_params), func->params.size(), ValueError) << "Annotation for attr::kNumInput (\"" << attr::kNumInput << "\") specifies " << num_input_params << " parameters to be provided at runtime, " << "but the function only accepts " << func->params.size() << " parameters in total"; diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index f7c49d0da8df..15ba2b82e8b4 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -212,7 +212,7 @@ struct LocalCollectInfo : public BaseCollectInfo { } Function MakeCompileTimeFunction() const { - ICHECK(!global_info); // This function is only called for local lifting + TVM_FFI_ICHECK(!global_info); // This function is only called for local lifting return MakeCompileTimeFunctionHelper(GetCompileTimeInputs(), computable_at_compile_time, GetPropagatedSymbolicVariables(), GetCompileTimeOutputs()); } @@ -517,7 +517,7 @@ class ParamRemapper : private ExprFunctor { int num_params = static_cast(functions[0]->params.size()) - num_inputs_0; for (int i = 0; i < static_cast(functions.size()); i++) { auto num_inputs_i = functions[i]->GetAttr(attr::kNumInput).value()->value; - CHECK_EQ(num_params, static_cast(functions[i]->params.size()) - num_inputs_i) + TVM_FFI_ICHECK_EQ(num_params, static_cast(functions[i]->params.size()) - num_inputs_i) << "The number of parameters should be the same for all target functions"; for (int j = 0; j < num_params; j++) { @@ -538,19 +538,19 @@ class ParamRemapper : private ExprFunctor { void VisitExpr_(const VarNode* lhs_var, const Expr& rhs_expr) final { auto rhs_var = Downcast(rhs_expr); if (auto it = var_remap_.find(ffi::GetRef(lhs_var)); it != var_remap_.end()) { - CHECK((*it).second.same_as(rhs_var)); + TVM_FFI_ICHECK((*it).second.same_as(rhs_var)); } else { var_remap_.Set(ffi::GetRef(lhs_var), rhs_var); } - CHECK(tvm::ffi::StructuralEqual::Equal(lhs_var->struct_info_, rhs_var->struct_info_, - /*map_free_vars=*/true)) + TVM_FFI_ICHECK(tvm::ffi::StructuralEqual::Equal(lhs_var->struct_info_, rhs_var->struct_info_, + /*map_free_vars=*/true)) << "The struct info of the parameters should be the same for all target functions"; auto lhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(ffi::GetRef(lhs_var))); auto rhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(rhs_expr)); - ICHECK_EQ(lhs_tir_vars.size(), rhs_tir_vars.size()); + TVM_FFI_ICHECK_EQ(lhs_tir_vars.size(), rhs_tir_vars.size()); for (size_t i = 0; i < lhs_tir_vars.size(); i++) { if (auto it = tir_var_remap_.find(lhs_tir_vars[i]); it != tir_var_remap_.end()) { - CHECK((*it).second.same_as(rhs_tir_vars[i])); + TVM_FFI_ICHECK((*it).second.same_as(rhs_tir_vars[i])); } else { tir_var_remap_.Set(lhs_tir_vars[i], rhs_tir_vars[i]); } @@ -567,7 +567,7 @@ class GlobalLiftableBindingCollector : public BaseLiftableBindingCollector { const ffi::Map& var_remap, const ffi::Map& tir_var_remap) { GlobalLiftableBindingCollector collector(var_remap, tir_var_remap); - ICHECK(functions.size()); + TVM_FFI_ICHECK(functions.size()); for (const auto& func : functions) { int num_inputs = func->GetAttr(attr::kNumInput).value()->value; for (int i = num_inputs; i < static_cast(func->params.size()); i++) { @@ -616,7 +616,8 @@ class GlobalLiftableBindingCollector : public BaseLiftableBindingCollector { const ffi::Map tir_var_remap) : var_remap_(var_remap), tir_var_remap_(tir_var_remap) {} void VisitBinding(const Binding& binding) override { - CHECK(!binding->IsInstance()) << "MatchCast is not supported in global lifting"; + TVM_FFI_ICHECK(!binding->IsInstance()) + << "MatchCast is not supported in global lifting"; if (CanLiftBinding(binding)) { liftable_vars_.insert(binding->var); auto bound_value = GetBoundValue(binding); @@ -687,11 +688,11 @@ class ConsumeBundledParams : public ExprMutator { Expr VisitExpr_(const FunctionNode* func) final { auto opt_num_input = func->GetAttr(attr::kNumInput); - ICHECK(opt_num_input.defined()); + TVM_FFI_ICHECK(opt_num_input.defined()); auto num_input = opt_num_input.value()->value; - ICHECK_EQ(func->params.size(), num_input + 1); + TVM_FFI_ICHECK_EQ(func->params.size(), num_input + 1); params_ = func->params.back(); - ICHECK(params_->struct_info_.as()); + TVM_FFI_ICHECK(params_->struct_info_.as()); return ExprMutator::VisitExpr_(func); } @@ -707,22 +708,23 @@ std::vector> GetTargetFunctions( auto names = shared_transform.as>().value(); for (const auto& name : names) { auto gvar = mod->global_var_map_.Get(name); - CHECK(gvar) << "When LiftTransformParams is called with a list of function names, " - << "all function names must occur within the IRModule. " - << "However, the IRModule does not contain a function names '" << name << "'"; + TVM_FFI_ICHECK(gvar) << "When LiftTransformParams is called with a list of function names, " + << "all function names must occur within the IRModule. " + << "However, the IRModule does not contain a function names '" << name + << "'"; auto base_func = mod->functions.Get(gvar.value()); - ICHECK(base_func.has_value()) + TVM_FFI_ICHECK(base_func.has_value()) << "Ill-formed IRModule. " << "The map from name to GlobalVar found " << gvar.value() << " for the function name '" << name << "', but this GlobalVar does not appear in the IRModule"; auto func = base_func.value().as(); - CHECK(func) << "When LiftTransformParams is called with a list of function names, " - << "only functions in the list must be relax functions. " - << "However, the function " << name << " is of type " - << base_func.value()->GetTypeKey(); - CHECK(func.value()->GetAttr(attr::kNumInput)) + TVM_FFI_ICHECK(func) << "When LiftTransformParams is called with a list of function names, " + << "only functions in the list must be relax functions. " + << "However, the function " << name << " is of type " + << base_func.value()->GetTypeKey(); + TVM_FFI_ICHECK(func.value()->GetAttr(attr::kNumInput)) << "When LiftTransformParams is called with a list of function names, " << "all functions in the list must have the kNumInput ('" << attr::kNumInput << "') attribute. " @@ -757,7 +759,7 @@ Pass PartitionTransformParams(ffi::Variant> shared auto pass_func = [=](IRModule mod, PassContext pc) { std::optional global_collect_info; - CHECK((shared_transform.as() || shared_transform.as>())) + TVM_FFI_ICHECK((shared_transform.as() || shared_transform.as>())) << "shared_transform should be a boolean or an array of function names"; auto target_functions = GetTargetFunctions(mod, shared_transform); diff --git a/src/relax/transform/lower_alloc_tensor.cc b/src/relax/transform/lower_alloc_tensor.cc index d63cd15744bd..d1eb58125a37 100644 --- a/src/relax/transform/lower_alloc_tensor.cc +++ b/src/relax/transform/lower_alloc_tensor.cc @@ -41,9 +41,10 @@ class Mutator : public ExprMutator { static const Op& mem_alloc_tensor_op = Op::Get("relax.memory.alloc_tensor"); if (op->op.same_as(alloc_tensor_op)) { - CHECK_EQ(op->args.size(), 4) << "Op " << op->op << " should have three arguments, " - << "[shape, dtype, runtime_device_index, storage_scope]. " - << "However, received " << ffi::GetRef(op); + TVM_FFI_ICHECK_EQ(op->args.size(), 4) + << "Op " << op->op << " should have three arguments, " + << "[shape, dtype, runtime_device_index, storage_scope]. " + << "However, received " << ffi::GetRef(op); auto shape_arg = op->args[0]; auto dtype = Downcast(op->args[1]); @@ -62,9 +63,10 @@ class Mutator : public ExprMutator { } } - LOG(FATAL) << "Shape argument for " << alloc_tensor_op << " should be a ShapeExpr, " - << "or a variable that holds a ShapeExpr. " - << "However, received argument " << shape_arg << " with struct info " << sinfo; + TVM_FFI_THROW(InternalError) + << "Shape argument for " << alloc_tensor_op << " should be a ShapeExpr, " + << "or a variable that holds a ShapeExpr. " + << "However, received argument " << shape_arg << " with struct info " << sinfo; TVM_FFI_UNREACHABLE(); }(); diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc index e8a9b74d94c4..86c5b83a8aab 100644 --- a/src/relax/transform/merge_composite_functions.cc +++ b/src/relax/transform/merge_composite_functions.cc @@ -108,7 +108,7 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { if (const auto* node = binding.as()) { return VisitBinding_(node); } else { - LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << binding->GetTypeKey(); } } @@ -130,7 +130,7 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { } else if (const auto* node = block.as()) { VisitBindingBlock_(node); } else { - LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << block->GetTypeKey(); } } @@ -197,7 +197,7 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { } void MergeGroup(Group* from, Group* to) { - ICHECK_EQ(GetCodegenName(from), GetCodegenName(to)); + TVM_FFI_ICHECK_EQ(GetCodegenName(from), GetCodegenName(to)); Group* from_root = from->FindRoot(); Group* to_root = to->FindRoot(); @@ -245,8 +245,8 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator { return; } - ICHECK(memo_.count(expr)) << "Could not find memo-ized group for expression of type " - << expr->GetTypeKey(); + TVM_FFI_ICHECK(memo_.count(expr)) + << "Could not find memo-ized group for expression of type " << expr->GetTypeKey(); auto arg_group_root = memo_[expr]->FindRoot(); if (arg_group_root == group_root) { diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index 8bd5c3f4b68a..06ac38a77dda 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -79,14 +79,14 @@ Pass MetaScheduleApplyDatabase(ffi::Optional work_dir, bool enable_ Target target = Target::Current(false); const std::optional normalize_mod_func_ = tvm::ffi::Function::GetGlobalRequired("tvm.s_tir.meta_schedule.normalize_mod"); - ICHECK(normalize_mod_func_.has_value()) << "Normalization function is not found."; + TVM_FFI_ICHECK(normalize_mod_func_.has_value()) << "Normalization function is not found."; auto pass_func = [=](IRModule mod, PassContext ctx) { Database database{ffi::UnsafeInit()}; if (Database::Current().defined()) { database = Database::Current().value(); } else { - ICHECK(work_dir.has_value()); + TVM_FFI_ICHECK(work_dir.has_value()); std::filesystem::create_directories(work_dir.value().c_str()); ffi::String path_workload = work_dir.value() + "/database_workload.json"; ffi::String path_tuning_record = work_dir.value() + "/database_tuning_record.json"; @@ -124,9 +124,9 @@ Pass MetaScheduleApplyDatabase(ffi::Optional work_dir, bool enable_ record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false); } IRModule new_mod = sch->mod(); - ICHECK_EQ(new_mod->functions.size(), 1); + TVM_FFI_ICHECK_EQ(new_mod->functions.size(), 1); BaseFunc new_base_func = (*new_mod->functions.begin()).second; - ICHECK(new_base_func->IsInstance()); + TVM_FFI_ICHECK(new_base_func->IsInstance()); tir::PrimFunc tuned_prim_func = Downcast(new_base_func); // maintain the original attributes tir::PrimFunc new_prim_func = tir::PrimFunc(/*params=*/tuned_prim_func->params, diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index e764e333f721..43fea7897f24 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -113,7 +113,7 @@ class NormalizeMutator : public ExprMutatorBase { } else if (const auto* node = block.as()) { ret = VisitBindingBlock_(node); } else { - LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << block->GetTypeKey(); } return ret; } @@ -140,7 +140,7 @@ class NormalizeMutator : public ExprMutatorBase { } else if (const auto* node = binding.as()) { VisitBinding_(node); } else { - LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << binding->GetTypeKey(); } } @@ -236,7 +236,7 @@ class GlobalVarNormalizer : private ExprMutator { } auto global_symbol_value = global_symbol.value(); - CHECK(!name_supply_->ContainsName(global_symbol_value)) + TVM_FFI_ICHECK(!name_supply_->ContainsName(global_symbol_value)) << "IRModule contains duplicate global symbol: " << global_symbol_value; name_supply_->ReserveName(global_symbol_value); auto new_gvar = builder_->AddFunction(func, global_symbol_value); @@ -262,7 +262,7 @@ class GlobalVarNormalizer : private ExprMutator { } Expr VisitExpr_(const GlobalVarNode* op) final { - ICHECK(gvar_map_.count(ffi::GetRef(op))); + TVM_FFI_ICHECK(gvar_map_.count(ffi::GetRef(op))); return gvar_map_[ffi::GetRef(op)]; } diff --git a/src/relax/transform/realize_vdevice.cc b/src/relax/transform/realize_vdevice.cc index 7f1042d57ecc..e2af04edd54f 100644 --- a/src/relax/transform/realize_vdevice.cc +++ b/src/relax/transform/realize_vdevice.cc @@ -42,9 +42,9 @@ class VDeviceLookup { if (auto vdevice = info.as()) { return vdevice.value(); } else { - LOG(FATAL) << "TypeError: " - << "Each item in an IRModule's \"vdevice\" annotation must be a VDevice, " - << "but instead found item of type " << info->GetTypeKey(); + TVM_FFI_THROW(TypeError) + << "Each item in an IRModule's \"vdevice\" annotation must be a VDevice, " + << "but instead found item of type " << info->GetTypeKey(); } }; @@ -53,17 +53,17 @@ class VDeviceLookup { VDevice operator()(Attrs hint_on_device_attrs) { auto attrs = hint_on_device_attrs.as(); - ICHECK(attrs); + TVM_FFI_ICHECK(attrs); int32_t device_type = attrs->device_type; int32_t device_id = attrs->index; ffi::String memory_scope = attrs->memory_scope; - CHECK(opt_vdevices_.defined()) - << "ValueError: The target VDevice in the GlobalInfos was not found."; + TVM_FFI_CHECK(opt_vdevices_.defined(), ValueError) + << "The target VDevice in the GlobalInfos was not found."; auto vdevices = opt_vdevices_.value(); - CHECK_GE(device_id, 0) << "ValueError: " - << "The device id in R.hint_on_device must not be negative"; + TVM_FFI_CHECK_GE(device_id, 0, ValueError) + << "The device id in R.hint_on_device must not be negative"; for (auto vdevice : vdevices) { int dev_type = vdevice->target->GetTargetDeviceType(); @@ -72,9 +72,9 @@ class VDeviceLookup { return vdevice; } } - LOG(FATAL) << "ValueError: " - << "Expected to find device with type " << device_id << " and id " << device_id - << ", but no such device was found in the IRModule's \"vdevice\" annotation"; + TVM_FFI_THROW(ValueError) + << "Expected to find device with type " << device_id << " and id " << device_id + << ", but no such device was found in the IRModule's \"vdevice\" annotation"; TVM_FFI_UNREACHABLE(); } @@ -139,8 +139,7 @@ class DeviceHintCollector : ExprVisitor { // output, or may return the result of a `relax::Call` that // produces a tuple of outputs. if (auto tuple = expr.as()) { - CHECK_EQ(tuple_info->fields.size(), tuple->fields.size()) - << "ValueError: " + TVM_FFI_CHECK_EQ(tuple_info->fields.size(), tuple->fields.size(), ValueError) << "Function returns a tuple with " << tuple->fields.size() << " elements, " << "but is annotated as returning a tuple with " << tuple_info->fields.size() << " elements"; @@ -173,7 +172,7 @@ class DeviceHintCollector : ExprVisitor { auto vdevice = vdevice_lookup_(call->attrs); known_vdevice_.Set(binding->var, vdevice); - ICHECK_EQ(call->args.size(), 1); + TVM_FFI_ICHECK_EQ(call->args.size(), 1); if (auto arg_var = call->args[0].as()) { hint_on_device_inputs_.Set(arg_var.value(), vdevice); } @@ -384,7 +383,7 @@ class VDeviceStructInfoUpdater : ExprMutator { return call; } - ICHECK_EQ(call->args.size(), 1); + TVM_FFI_ICHECK_EQ(call->args.size(), 1); auto arg = call->args[0]; auto input_vdevice = Downcast(arg->struct_info_)->vdevice; auto output_vdevice = vdevice_lookup_(call->attrs); diff --git a/src/relax/transform/remove_unused_outputs.cc b/src/relax/transform/remove_unused_outputs.cc index 83170abd635b..9de26d8b1a4e 100644 --- a/src/relax/transform/remove_unused_outputs.cc +++ b/src/relax/transform/remove_unused_outputs.cc @@ -96,14 +96,12 @@ class PartialTupleUsageCollector : ExprVisitor { if (auto* usage_mask_ptr = GetCalleeUsageMask(op->tuple)) { auto& used_indices = *usage_mask_ptr; - CHECK_GE(op->index, 0) << "IndexError: " - << "Indices for TupleGetItem must be non-negative, " - << "but expression " << ffi::GetRef(op) - << " uses a tuple index of " << op->index; + TVM_FFI_CHECK_GE(op->index, 0, IndexError) + << "Indices for TupleGetItem must be non-negative, " + << "but expression " << ffi::GetRef(op) << " uses a tuple index of " << op->index; size_t index = op->index; - CHECK_LT(index, used_indices.size()) - << "IndexError: " + TVM_FFI_CHECK_LT(index, used_indices.size(), IndexError) << "Indices for TupleGetItem must be less than the size of the tuple, " << "but expression " << ffi::GetRef(op) << " uses a tuple index of " << op->index << " for a tuple of size " << used_indices.size(); @@ -161,8 +159,8 @@ Function UpdateCallee(Function func, const std::vector& usage_mask) { auto old_func_sinfo = func->struct_info_.as(); auto old_ret_sinfo = func->ret_struct_info.as(); - ICHECK(old_ret_sinfo) << "All functions returning non-tuple outputs " - << "should have been pruned already by PartialTupleUsageCollector"; + TVM_FFI_ICHECK(old_ret_sinfo) << "All functions returning non-tuple outputs " + << "should have been pruned already by PartialTupleUsageCollector"; ffi::Array outputs; @@ -246,16 +244,15 @@ Pass RemoveUnusedOutputs() { new_callees->Add(new_gvar, new_func); callsite_updaters[gvar] = [old_gvar = gvar, new_gvar, usage_mask](Call call) -> Expr { - ICHECK(call->op.same_as(old_gvar)) << "InternalError: " - << "Updater should be applied to " << old_gvar - << ", but was applied to " << call->op; + TVM_FFI_CHECK(call->op.same_as(old_gvar), InternalError) + << "Updater should be applied to " << old_gvar << ", but was applied to " + << call->op; auto old_call_sinfo = call->struct_info_.as(); - ICHECK(old_call_sinfo) - << "InternalError: " + TVM_FFI_CHECK(old_call_sinfo, InternalError) << "Updater should be applied to Call producing an output tuple, " << "but " << call << " has struct info " << call->struct_info_; - CHECK_EQ(usage_mask.size(), old_call_sinfo->fields.size()) + TVM_FFI_ICHECK_EQ(usage_mask.size(), old_call_sinfo->fields.size()) << "Function " << call->op << " produces " << usage_mask.size() << " outputs, " << "but " << call << " was used in a context expecting " << old_call_sinfo->fields.size() << " outputs."; diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc index 5003dec8a8d2..b59cb61d2853 100644 --- a/src/relax/transform/remove_unused_parameters.cc +++ b/src/relax/transform/remove_unused_parameters.cc @@ -111,7 +111,7 @@ std::optional AnalyzeCallee(Function func) { auto arg_updater = [parameter_mask, old_relax_params = func->params, free_tir_vars](ffi::Array old_args) -> ffi::Array { - ICHECK_EQ(old_args.size(), parameter_mask.size()) + TVM_FFI_ICHECK_EQ(old_args.size(), parameter_mask.size()) << "Call provides " << old_args.size() << ", but the callee accepts " << parameter_mask.size() << " parameters"; @@ -201,9 +201,9 @@ Pass RemoveUnusedParameters() { callsite_updaters[gvar] = [old_gvar = gvar, new_gvar, arg_updater = callee_res->arg_updater](Call call) -> Call { - ICHECK(call->op.same_as(old_gvar)) << "InternalError: " - << "Updater should be applied to " << old_gvar - << ", but was applied to " << call->op; + TVM_FFI_CHECK(call->op.same_as(old_gvar), InternalError) + << "Updater should be applied to " << old_gvar << ", but was applied to " + << call->op; auto write_ptr = call.CopyOnWrite(); write_ptr->op = new_gvar; write_ptr->args = arg_updater(call->args); diff --git a/src/relax/transform/reorder_permute_dims_after_concat.cc b/src/relax/transform/reorder_permute_dims_after_concat.cc index 73bc1853816e..bc542ccf91ef 100644 --- a/src/relax/transform/reorder_permute_dims_after_concat.cc +++ b/src/relax/transform/reorder_permute_dims_after_concat.cc @@ -71,7 +71,7 @@ std::tuple)>> } auto make_pattern_with_num_concat = [&](size_t num_concat) -> DFPattern { - ICHECK_LT(num_concat, pat_permute_dims.size()); + TVM_FFI_ICHECK_LT(num_concat, pat_permute_dims.size()); auto concat_tuple = TuplePattern( ffi::Array(pat_permute_dims.begin(), pat_permute_dims.begin() + num_concat)); return IsOp("relax.concat")(concat_tuple); @@ -84,9 +84,9 @@ std::tuple)>> auto get_permute_dims_optional_axes = [](const Expr& expr) -> ffi::Optional> { auto call = expr.as(); - ICHECK(call); + TVM_FFI_ICHECK(call); auto attrs = call->attrs.as(); - ICHECK(attrs); + TVM_FFI_ICHECK(attrs); return attrs->axes; }; @@ -99,10 +99,10 @@ std::tuple)>> auto call = Downcast(expr); ffi::Array permutation; auto arg_sinfo = call->args[0]->struct_info_.as(); - CHECK(arg_sinfo) << "Expected permute_dims to have a single tensor argument, " - << "but argument " << call->args[0] << " has struct info " - << call->args[0]->struct_info_; - CHECK_GE(arg_sinfo->ndim, 0); + TVM_FFI_ICHECK(arg_sinfo) << "Expected permute_dims to have a single tensor argument, " + << "but argument " << call->args[0] << " has struct info " + << call->args[0]->struct_info_; + TVM_FFI_ICHECK_GE(arg_sinfo->ndim, 0); size_t ndim = arg_sinfo->ndim; for (size_t i = 0; i < ndim; i++) { permutation.push_back(Integer(ndim - i - 1)); @@ -137,8 +137,7 @@ std::tuple)>> } } - ICHECK_GE(all_permute_dims.size(), min_concat) - << "InternalError: " + TVM_FFI_CHECK_GE(all_permute_dims.size(), min_concat, InternalError) << "Pattern match should return at least " << min_concat << " items, but only found " << all_permute_dims.size() << ": " << all_permute_dims; @@ -150,7 +149,7 @@ std::tuple)>> Call concat_call = Downcast(matches[pat_concat]); auto concat_attrs = concat_call->attrs.as(); - ICHECK(concat_attrs); + TVM_FFI_ICHECK(concat_attrs); auto old_concat_axis = [&]() -> size_t { return concat_attrs->axis.value_or(0); }(); Integer new_concat_axis = get_permute_dims_axes(all_permute_dims[0])[old_concat_axis]; diff --git a/src/relax/transform/reorder_take_after_matmul.cc b/src/relax/transform/reorder_take_after_matmul.cc index 25f245101b1b..260f1a5525f5 100644 --- a/src/relax/transform/reorder_take_after_matmul.cc +++ b/src/relax/transform/reorder_take_after_matmul.cc @@ -56,15 +56,13 @@ std::tuple)>> auto indices = matches[pat_indices]; const auto* take_call = matches[pat_rhs].as(); - ICHECK(take_call) << "InternalError: " - << "Match of relax.take operator should produce Call, " - << "but instead produces " << matches[pat_rhs] << " with type " - << matches[pat_rhs]->GetTypeKey(); + TVM_FFI_CHECK(take_call, InternalError) << "Match of relax.take operator should produce Call, " + << "but instead produces " << matches[pat_rhs] + << " with type " << matches[pat_rhs]->GetTypeKey(); const auto* attrs = take_call->attrs.as(); - ICHECK(attrs) << "InternalError: " - << "Attributes for relax.take operator should be TakeAttrs, " - << "but were instead " << take_call->attrs << " with type " - << take_call->GetTypeKey(); + TVM_FFI_CHECK(attrs, InternalError) + << "Attributes for relax.take operator should be TakeAttrs, " + << "but were instead " << take_call->attrs << " with type " << take_call->GetTypeKey(); const auto* lhs_sinfo = lhs->struct_info_.as(); if (!lhs_sinfo) return expr; diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc index 9749599c2f85..4fa00d9612ad 100644 --- a/src/relax/transform/rewrite_cuda_graph.cc +++ b/src/relax/transform/rewrite_cuda_graph.cc @@ -273,7 +273,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { auto* plan = arena_->make(); plan->is_alloc = true; plan->func = region->Build(); - ICHECK(region->size()); + TVM_FFI_ICHECK(region->size()); plan->launch_point = region->bindings_.front()->var.get(); plan->is_alloc = is_alloc; plan->lifted_bindings = std::move(region->bindings_); @@ -471,7 +471,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* tuple_get_item) final { const VarNode* tuple = tuple_get_item->tuple.as(); - ICHECK(tuple); + TVM_FFI_ICHECK(tuple); if (IsStatic(tuple_get_item->tuple)) { AddStaticBinding(binding, false); MarkAsFuncInput({tuple}); @@ -666,17 +666,17 @@ Function MergeAllocationPlans(const std::vector& all // for each original function. for (int plan_id = 0; plan_id < static_cast(alloc_plans.size()); ++plan_id) { LiftedFunctionRewritePlan* plan = alloc_plans[plan_id]; - ICHECK(plan->is_alloc); + TVM_FFI_ICHECK(plan->is_alloc); for (const VarBindingNode* binding : plan->lifted_bindings) { // Extract the stroage record from the Call expr. Call alloc_storage = Downcast(binding->value); - ICHECK(alloc_storage->op.same_as(mem_alloc_storage_op)); + TVM_FFI_ICHECK(alloc_storage->op.same_as(mem_alloc_storage_op)); auto storage_shape = Downcast(alloc_storage->args[0]); - ICHECK_EQ(storage_shape->values.size(), 1); + TVM_FFI_ICHECK_EQ(storage_shape->values.size(), 1); int64_t size = Downcast(storage_shape->values[0])->value; int64_t virtual_device_id = Downcast(Downcast(alloc_storage->args[1])->value)->value; - ICHECK_EQ(virtual_device_id, 0); + TVM_FFI_ICHECK_EQ(virtual_device_id, 0); ffi::String storage_scope = Downcast(alloc_storage->args[2])->value; auto [it, _] = storage_records.try_emplace(storage_scope, alloc_plans.size()); it->second[plan_id].emplace_back(StorageRecord{size, binding, plan}); @@ -780,8 +780,8 @@ class CUDAGraphRewriter : public ExprMutator { Expr launch_subgraph; if (plan->is_alloc) { // Storage allocation should be fully static and shouldn't depend on any symbolic variables. - ICHECK(!plan->propogated_tir_vars.defined()); - ICHECK(plan->inputs.empty()); + TVM_FFI_ICHECK(!plan->propogated_tir_vars.defined()); + TVM_FFI_ICHECK(plan->inputs.empty()); auto gv_alloc = gv_global_alloc_.value(); auto ret_struct_info = Downcast(gv_alloc->struct_info_.value())->ret; launch_subgraph = Call( @@ -806,7 +806,7 @@ class CUDAGraphRewriter : public ExprMutator { auto symbolic_params = Downcast(shape_expr->struct_info_.value())->values.value(); ffi::Map tir_var_remap; - ICHECK_EQ(symbolic_params.size(), propogated_tir_vars->values.size()); + TVM_FFI_ICHECK_EQ(symbolic_params.size(), propogated_tir_vars->values.size()); for (int i = 0; i < static_cast(symbolic_params.size()); ++i) { tir_var_remap.Set(Downcast(symbolic_params[i]), propogated_tir_vars->values[i]); } diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc index fdaa2b927e2e..7cd0e46cb328 100644 --- a/src/relax/transform/rewrite_dataflow_reshape.cc +++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -110,7 +110,7 @@ class DataflowReshapeRewriter : public ExprMutator { bool IsCallingTIRReshape(const CallNode* call, Expr inp) { const GlobalVar& global_var = Downcast(call->args[0]); const auto* func = mod_->functions.Get(global_var).value().as(); - ICHECK_NOTNULL(func); + TVM_FFI_ICHECK_NOTNULL(func); if (!HasReshapePattern(ffi::GetRef(func))) { return false; } @@ -119,14 +119,14 @@ class DataflowReshapeRewriter : public ExprMutator { // as the number of elements in the result. There are operators that could have a reshape // pattern that don't meet this requirement (e.g. strided_slice), and they should not be // converted to reshape. - ICHECK(inp->struct_info_.defined() && call->struct_info_.defined()); + TVM_FFI_ICHECK(inp->struct_info_.defined() && call->struct_info_.defined()); TensorStructInfo inp_sinfo = Downcast(inp->struct_info_.value()); TensorStructInfo res_sinfo = Downcast(call->struct_info_.value()); if (inp_sinfo->IsUnknownDtype() || inp_sinfo->dtype != res_sinfo->dtype) { return false; } - ICHECK(inp_sinfo->shape.defined() && res_sinfo->shape.defined()); + TVM_FFI_ICHECK(inp_sinfo->shape.defined() && res_sinfo->shape.defined()); if (inp_sinfo->IsUnknownNdim() || res_sinfo->IsUnknownNdim()) { return false; } diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index 71d557d031cf..d02e62fc4d2a 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -92,7 +92,7 @@ class CodeGenRunner : ExprMutator { // Some backends (e.g. TensorRT) expect constants to be passed when they are instantiated ffi::Map constants; for (const auto& [constant, name] : constant_names) { - ICHECK(!constants.count(name)) << "More than one constant with the name " << name; + TVM_FFI_ICHECK(!constants.count(name)) << "More than one constant with the name " << name; constants.Set(name, constant->data); } out_mod = WithAttr(out_mod, tvm::attr::kConstNameToConstant, std::move(constants)); @@ -132,7 +132,7 @@ class CodeGenRunner : ExprMutator { // Remove the global symbol and codegen attributes from the function so that it can be // removed the module. const auto RemoveFuncAttrFunc = tvm::ffi::Function::GetGlobal("ir.BaseFuncWithoutAttr"); - ICHECK(RemoveFuncAttrFunc.has_value()); + TVM_FFI_ICHECK(RemoveFuncAttrFunc.has_value()); func = (*RemoveFuncAttrFunc)(func, tvm::attr::kGlobalSymbol).cast(); func = (*RemoveFuncAttrFunc)(func, attr::kCodegen).cast(); builder_->UpdateFunction(gvar, func); @@ -196,7 +196,7 @@ class CodeGenRunner : ExprMutator { // Get the codegen with its ffi key. ffi::String codegen_name = "relax.ext." + target; const auto codegen = tvm::ffi::Function::GetGlobal(codegen_name); - ICHECK(codegen.has_value()) << "Codegen is not found: " << codegen_name << "\n"; + TVM_FFI_ICHECK(codegen.has_value()) << "Codegen is not found: " << codegen_name << "\n"; ffi::Array compiled_functions = (*codegen)(functions, options, constant_names).cast>(); diff --git a/src/relax/transform/specialize_primfunc_based_on_callsite.cc b/src/relax/transform/specialize_primfunc_based_on_callsite.cc index bb2855b4b0bb..10fc575e729d 100644 --- a/src/relax/transform/specialize_primfunc_based_on_callsite.cc +++ b/src/relax/transform/specialize_primfunc_based_on_callsite.cc @@ -42,7 +42,7 @@ using tvm::tir::Buffer; static ffi::Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { auto shape = tensor_sinfo->GetShape(); - ICHECK(shape.defined()); + TVM_FFI_ICHECK(shape.defined()); return shape.value(); } @@ -85,10 +85,10 @@ class SpecializeTIRCallArgs : ExprMutator { for (size_t i = 0; i < args.size(); ++i) { auto sinfo = GetStructInfo(args[i]); - CHECK(sinfo->IsInstance()) + TVM_FFI_ICHECK(sinfo->IsInstance()) << "Expected Tensor struct Info for call :" << call->op; auto tensor_sinfo = Downcast(sinfo); - CHECK(tensor_sinfo->shape.defined()) << "Shape undefined for call:" << call->args[0]; + TVM_FFI_ICHECK(tensor_sinfo->shape.defined()) << "Shape undefined for call:" << call->args[0]; ffi::String scope = "global"; if (tensor_sinfo->vdevice.defined()) { scope = tensor_sinfo->vdevice.value()->memory_scope; @@ -115,7 +115,7 @@ class SpecializeTIRCallArgs : ExprMutator { tir::decl_buffer(GetShapeFromTensorStructInfo(sinfo), sinfo->dtype, "ret_val", scope); param_map.Set(pfunc->params[pfunc->params.size() - 1], buffer); } else { - ICHECK(out_sinfo->IsInstance()) + TVM_FFI_ICHECK(out_sinfo->IsInstance()) << "Expect output struct info of call_tir to be either TupleStructInfo or " "TensorStructInfo, but got " << out_sinfo; @@ -124,7 +124,7 @@ class SpecializeTIRCallArgs : ExprMutator { ffi::Array sinfo_fields; int index = 0; for (const auto& si : tuple_sinfo->fields) { - ICHECK(si->IsInstance()) + TVM_FFI_ICHECK(si->IsInstance()) << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " "output structinfo, but got " << si; @@ -143,7 +143,7 @@ class SpecializeTIRCallArgs : ExprMutator { auto new_pfunc = Specialize(pfunc, param_map); for (const auto& [var, buffer] : new_pfunc->buffer_map) { auto* ptr = buffer->data->type_annotation.as(); - ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; + TVM_FFI_ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; } auto new_prim_func = WithAttr(new_pfunc, "scoped", Integer(1)); updates_->Add(gv, new_prim_func); diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc index 49a76d03e5a2..c8a71864fe82 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -63,7 +63,7 @@ class ForMatcher : public TensorizeComparator { bool Match(const For& top) { const ForNode* pattern_top = pattern_->body.as()->block->body.as(); - ICHECK(pattern_top) << "Invalid pattern function"; + TVM_FFI_ICHECK(pattern_top) << "Invalid pattern function"; if (!VisitStmt(top, ffi::GetRef(pattern_top))) { return false; } @@ -72,7 +72,7 @@ class ForMatcher : public TensorizeComparator { auto it = pattern_->buffer_map.find(arg); if (it != pattern_->buffer_map.end()) { auto itt = rhs_buffer_map_.find((*it).second); - ICHECK(itt != rhs_buffer_map_.end()); + TVM_FFI_ICHECK(itt != rhs_buffer_map_.end()); evaluated_buffers.push_back(itt->second); } } @@ -520,7 +520,7 @@ class BlockRemover : public StmtExprMutator { SBlock block = Downcast(StmtExprMutator::VisitStmt_(op)); ObjectPtr n = ffi::make_object(*block.operator->()); if (op->name_hint != "root") { - ICHECK(block_partition.count(ffi::GetRef(op))); + TVM_FFI_ICHECK(block_partition.count(ffi::GetRef(op))); bool block_is_library = block_partition[ffi::GetRef(op)]->value; if (!(is_library_part_ ^ block_is_library)) { n->body = block->body; @@ -578,7 +578,7 @@ std::pair> SplitFunctions( return {func, std::nullopt}; } ffi::Array codegen_result = f_codegen(match_results); - ICHECK(codegen_result.size() == 3); + TVM_FFI_ICHECK(codegen_result.size() == 3); ffi::String library_code = Downcast(codegen_result[0]); int num_matched_ops = Downcast(codegen_result[1])->value; ffi::Array func1_args = Downcast>(codegen_result[2]); @@ -609,9 +609,9 @@ std::pair> SplitFunctions( // Step 3. Craft the first function. ffi::Array new_params1; std::vector arg_partition1; - ICHECK_LE(func1_args.size(), partitioner.input1.size()); + TVM_FFI_ICHECK_LE(func1_args.size(), partitioner.input1.size()); for (const auto& buffer : func1_args) { - ICHECK(partitioner.input1.find(buffer) != partitioner.input1.end()); + TVM_FFI_ICHECK(partitioner.input1.find(buffer) != partitioner.input1.end()); for (size_t i = 0; i < func->params.size(); i++) { if (func->buffer_map[func->params[i]].same_as(buffer)) { new_params1.push_back(func->params[i]); @@ -729,7 +729,7 @@ class SplitMutator : public ExprMutator { tvm::BaseFunc lib_func = CodegenWithLibrary(split_funcs.first.get(), gv->name_hint); if (lib_func->IsInstance()) return ffi::GetRef(op); // Update the function in the module with the library kernel - ICHECK(lib_func->IsInstance()); + TVM_FFI_ICHECK(lib_func->IsInstance()); builder_->UpdateFunction(gv, lib_func); // emit the call to the library kernel ObjectPtr new_call = ffi::make_object(*call.operator->()); @@ -739,7 +739,7 @@ class SplitMutator : public ExprMutator { } tir::PrimFunc func1 = s_tir::RenewDefs(split_funcs.first); tir::PrimFunc func2 = s_tir::RenewDefs(split_funcs.second.value()); - ICHECK(arg_partition.size() == 2); + TVM_FFI_ICHECK(arg_partition.size() == 2); // emit the first call to the library kernel ffi::Array args1; for (int p : arg_partition[0]) { @@ -748,7 +748,7 @@ class SplitMutator : public ExprMutator { // replace the function in the module with the library kernel tvm::BaseFunc lib_func = CodegenWithLibrary(func1.get(), gv->name_hint); if (lib_func->IsInstance()) return ffi::GetRef(op); - ICHECK(lib_func->IsInstance()); + TVM_FFI_ICHECK(lib_func->IsInstance()); builder_->UpdateFunction(gv, lib_func); tir::Buffer intermediate_buffer = func1->buffer_map.at(func1->params.back()); DataType dtype = intermediate_buffer->dtype; diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc index 04fc385b989c..66f706e47882 100644 --- a/src/relax/transform/split_layout_rewrite_preproc.cc +++ b/src/relax/transform/split_layout_rewrite_preproc.cc @@ -37,7 +37,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { public: explicit SplitPrimFuncLayoutRewrite(const PrimFunc& func) : original_func_(func) {} std::tuple, PrimFunc> Transform(const PrimFunc& func) { - ICHECK(func->body.as()) + TVM_FFI_ICHECK(func->body.as()) << "The body of the primfunc should be a root block."; const auto& block = func->body.as()->block; visit_root_block(block.get()); @@ -57,7 +57,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { PrimFunc create_layout_rewrite_preproc_func() const { // Step 1: Check the number of pre_rewrite_buffers and post_rewrite_buffers - ICHECK(rewrite_infos_.size() > 0) << "There should be at least one buffer rewrite."; + TVM_FFI_ICHECK(rewrite_infos_.size() > 0) << "There should be at least one buffer rewrite."; // Step 2: Create the params for the new PrimFunc ffi::Array params; @@ -73,7 +73,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { } // Step 3: Create the body for the new PrimFunc - ICHECK(layout_rewrite_preproc_stmts_.size() > 0) + TVM_FFI_ICHECK(layout_rewrite_preproc_stmts_.size() > 0) << "There should be at least one layout rewrite preproc stmt."; Stmt body = layout_rewrite_preproc_stmts_.size() == 1 ? layout_rewrite_preproc_stmts_[0] : SeqStmt(layout_rewrite_preproc_stmts_); @@ -104,7 +104,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { ffi::Map buffer_map = original_func_->buffer_map; for (const auto& info : rewrite_infos_) { const Var& param = params[info.buffer_index]; - ICHECK(buffer_map[param] == info.pre_rewrite_buffer); + TVM_FFI_ICHECK(buffer_map[param] == info.pre_rewrite_buffer); buffer_map.Set(param, info.post_rewrite_buffer); } @@ -150,7 +150,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { for (const auto& stmt : seq_stmt->seq) { current_subtree_ = 0; Stmt new_stmt = this->VisitStmt(stmt); - ICHECK(current_subtree_ != 0) << "There should be at least a block in the subtree."; + TVM_FFI_ICHECK(current_subtree_ != 0) << "There should be at least a block in the subtree."; if (current_subtree_ == 1) { layout_rewrite_preproc_stmts_.push_back(new_stmt); } else { @@ -160,7 +160,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { } else { current_subtree_ = 0; this->VisitStmt(body); - ICHECK(current_subtree_ == -1) + TVM_FFI_ICHECK(current_subtree_ == -1) << "There should be a compute block if there is only one subtree under the root."; } } @@ -173,19 +173,22 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { if (current_subtree_ == 0) { current_subtree_ = is_layout_rewrite_preproc ? 1 : -1; } else if (current_subtree_ == 1) { - CHECK(is_layout_rewrite_preproc) + TVM_FFI_ICHECK(is_layout_rewrite_preproc) << "There is a layout rewrite block in the subtree, but meet a non-layout rewrite block."; } else { - CHECK(!is_layout_rewrite_preproc) + TVM_FFI_ICHECK(!is_layout_rewrite_preproc) << "There is a non-layout rewrite block in the subtree, but meet a layout rewrite block."; } if (is_layout_rewrite_preproc) { - ICHECK(op->reads.size() == 1) << "There should be only one read buffer in the layout rewrite"; - ICHECK(op->writes.size() == 1) + TVM_FFI_ICHECK(op->reads.size() == 1) + << "There should be only one read buffer in the layout rewrite"; + TVM_FFI_ICHECK(op->writes.size() == 1) << "There should be only one write buffer in the layout rewrite"; - ICHECK(op->alloc_buffers.empty()) << "There should be no alloc buffer in the layout rewrite"; - ICHECK(op->match_buffers.empty()) << "There should be no match buffer in the layout rewrite"; + TVM_FFI_ICHECK(op->alloc_buffers.empty()) + << "There should be no alloc buffer in the layout rewrite"; + TVM_FFI_ICHECK(op->match_buffers.empty()) + << "There should be no match buffer in the layout rewrite"; const Buffer& preproc_buffer = op->reads[0]->buffer; int buffer_index = -1; for (size_t i = 0; i < original_func_->params.size(); ++i) { @@ -195,7 +198,8 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { break; } } - ICHECK(buffer_index != -1) << "The preproc buffer is not found in the original primfunc."; + TVM_FFI_ICHECK(buffer_index != -1) + << "The preproc buffer is not found in the original primfunc."; rewrite_infos_.push_back( RewriteInfo{buffer_index, op->reads[0]->buffer, op->writes[0]->buffer}); @@ -287,7 +291,7 @@ class SplitLayoutRewritePreproc : public ExprMutator { GlobalVar compute_gv = builder_->AddFunction(compute_func, gv->name_hint + "_prepacked"); // Step 4. Get rewrite infos auto rewrite_infos_it = rewrite_infos_.find(gv.get()); - ICHECK(rewrite_infos_it != rewrite_infos_.end()) + TVM_FFI_ICHECK(rewrite_infos_it != rewrite_infos_.end()) << "Rewrite infos are not found for " << gv->name_hint; const auto& rewrite_infos = rewrite_infos_it->second; @@ -299,8 +303,9 @@ class SplitLayoutRewritePreproc : public ExprMutator { preproc_args.push_back(call_tir_args[info.buffer_index]); tir::Buffer rewritten_buffer = info.post_rewrite_buffer; for (const auto& shape_expr : rewritten_buffer->shape) { - CHECK(shape_expr.as()) << "Currently does not support rewrite buffer with " - "dynamic shape."; + TVM_FFI_ICHECK(shape_expr.as()) + << "Currently does not support rewrite buffer with " + "dynamic shape."; } preproc_sinfo_list.push_back( TensorStructInfo(ShapeExpr(rewritten_buffer->shape), rewritten_buffer->dtype)); diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 2a41687983a1..b4008ec2cdad 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -205,7 +205,8 @@ class TokenAllocatorMixed { */ ffi::Optional RequestReuse(StorageToken prototype) { // Step 0. Sanity check: the prototype token is supposed not to be allocated with actual storage - ICHECK_EQ(prototype->storage_id, -1) << "The token is expected not to be allocated before."; + TVM_FFI_ICHECK_EQ(prototype->storage_id, -1) + << "The token is expected not to be allocated before."; // If the prototype has no reference at all, feel free to allocate new storage. // The unused binding can be removed by cleaning passes. if (prototype->ref_counter == 0) { @@ -224,7 +225,7 @@ class TokenAllocatorMixed { for (; begin != end; ++begin) { StorageToken available_token = begin->second; if (analyzer_->CanProveEqual(prototype->bytes, available_token->bytes)) { - ICHECK_EQ(available_token->ref_counter, 0) + TVM_FFI_ICHECK_EQ(available_token->ref_counter, 0) << "Available tokens are expected to have 0 reference."; available_token->ref_counter = prototype->ref_counter; pool.erase(begin); @@ -240,9 +241,9 @@ class TokenAllocatorMixed { // Step 3. Search for memory block that equals or is larger than the requested size. if (mid != end) { StorageToken available_token = mid->second; - ICHECK_EQ(available_token->ref_counter, 0) + TVM_FFI_ICHECK_EQ(available_token->ref_counter, 0) << "Available tokens are expected to have 0 reference."; - ICHECK_LE(size, available_token->const_bytes()); + TVM_FFI_ICHECK_LE(size, available_token->const_bytes()); available_token->ref_counter = prototype->ref_counter; pool.erase(mid); return available_token; @@ -252,10 +253,10 @@ class TokenAllocatorMixed { --mid; StorageToken available_token = mid->second; int64_t available_size = available_token->const_bytes(); - ICHECK_EQ(available_token->ref_counter, 0) + TVM_FFI_ICHECK_EQ(available_token->ref_counter, 0) << "Available tokens are expected to have 0 reference."; - ICHECK_GE(available_size, 0); - ICHECK_GE(size, available_size); + TVM_FFI_ICHECK_GE(available_size, 0); + TVM_FFI_ICHECK_GE(size, available_size); // Enlarge the token size. available_token->bytes = tir::make_const(DataType::Int(64), size); available_token->ref_counter = prototype->ref_counter; @@ -274,7 +275,8 @@ class TokenAllocatorMixed { */ StorageToken Alloc(StorageToken prototype, int storage_id) { // Sanity check: the prototype token is supposed not to be allocated with actual storage yet - ICHECK_EQ(prototype->storage_id, -1) << "The token is expected not to be allocated before."; + TVM_FFI_ICHECK_EQ(prototype->storage_id, -1) + << "The token is expected not to be allocated before."; prototype->storage_id = storage_id; full_pool_.push_back(prototype); return prototype; @@ -286,9 +288,10 @@ class TokenAllocatorMixed { */ void Release(StorageToken token) { // Sanity check: the token has been allocated with actual storage, and should have 0 reference. - ICHECK_GE(token->storage_id, 0) + TVM_FFI_ICHECK_GE(token->storage_id, 0) << "The token to be released is expected to be allocated before"; - ICHECK_EQ(token->ref_counter, 0) << "The token to be released is expected to have 0 reference."; + TVM_FFI_ICHECK_EQ(token->ref_counter, 0) + << "The token to be released is expected to have 0 reference."; available_pool_[{token->storage_scope, token->dtype}].insert({token->const_bytes(), token}); } @@ -338,8 +341,8 @@ class StorageAllocatorBaseVisitor : public ExprVisitor { // We maintain a block stack for token allocation-site and use-site check. block_stack_.push_back(block); ExprVisitor::VisitBindingBlock_(block); - ICHECK(!block_stack_.empty()); - ICHECK(block_stack_.back() == block); + TVM_FFI_ICHECK(!block_stack_.empty()); + TVM_FFI_ICHECK(block_stack_.back() == block); block_stack_.pop_back(); } @@ -353,8 +356,8 @@ class StorageAllocatorBaseVisitor : public ExprVisitor { // We maintain a block stack for token allocation-site and use-site check. block_stack_.push_back(block); ExprVisitor::VisitBindingBlock_(block); - ICHECK(!block_stack_.empty()); - ICHECK(block_stack_.back() == block); + TVM_FFI_ICHECK(!block_stack_.empty()); + TVM_FFI_ICHECK(block_stack_.back() == block); block_stack_.pop_back(); } @@ -375,10 +378,10 @@ class StorageAllocatorBaseVisitor : public ExprVisitor { token_map_[tuple_item] = Tokens(); return; } - ICHECK(tokens.IsNested()); + TVM_FFI_ICHECK(tokens.IsNested()); ffi::Array field_tokens = tokens.NestedArray(); - ICHECK_GT(static_cast(field_tokens.size()), tuple_item->index); - ICHECK_GE(tuple_item->index, 0); + TVM_FFI_ICHECK_GT(static_cast(field_tokens.size()), tuple_item->index); + TVM_FFI_ICHECK_GE(tuple_item->index, 0); SetTokens(tuple_item, field_tokens[tuple_item->index]); } @@ -565,7 +568,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { call->op == call_tir_dyn_op) { ffi::Array args = call->op == call_tir_dyn_op ? Downcast(call->args[1])->fields : call->args; - ICHECK(!block_stack_.empty()); + TVM_FFI_ICHECK(!block_stack_.empty()); for (const Expr& arg : call->args) { Tokens tokens = GetTokensWithAllocSiteCheck(arg, block_stack_.back()); ForEachLeaf(tokens, [](StorageToken token) { token->ref_counter += 1; }); @@ -628,12 +631,12 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { // - the tensor has known dtype; // - no storage token was created for this call before. const auto* sinfo = call->struct_info_.as(); - ICHECK_NOTNULL(sinfo); + TVM_FFI_ICHECK_NOTNULL(sinfo); const auto* shape = sinfo->shape.as(); - ICHECK_NOTNULL(shape); - ICHECK(!sinfo->IsUnknownDtype()); - ICHECK(sinfo->dtype == Downcast(call->args[1])->value); - ICHECK(!token_map_.count(call)); + TVM_FFI_ICHECK_NOTNULL(shape); + TVM_FFI_ICHECK(!sinfo->IsUnknownDtype()); + TVM_FFI_ICHECK(sinfo->dtype == Downcast(call->args[1])->value); + TVM_FFI_ICHECK(!token_map_.count(call)); // Use the upper bounds of TIR vars as their values. The upper bound shape can still be dynamic // if the upper bounds of some variables are not provided. @@ -653,7 +656,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { Tokens tokens(token); SetTokens(call, tokens); - ICHECK(!block_stack_.empty()); + TVM_FFI_ICHECK(!block_stack_.empty()); token2block_[token.get()] = block_stack_.back(); return tokens; } @@ -685,7 +688,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { Tokens tokens = GetTokens(expr); ForEachLeaf(tokens, [this, cur_block](StorageToken token) { auto it = this->token2block_.find(token.get()); - ICHECK(it != this->token2block_.end()); + TVM_FFI_ICHECK(it != this->token2block_.end()); if (it->second != cur_block) { this->DiscardToken(token); } @@ -780,7 +783,7 @@ class StorageAllocator : public StorageAllocatorBaseVisitor { // Sanity check: each token allocated inside the block should not be // referenced by anyone at the end of the block. for (const StorageTokenNode* token : block2tokens[block]) { - ICHECK_EQ(token->ref_counter, 0); + TVM_FFI_ICHECK_EQ(token->ref_counter, 0); } } @@ -788,14 +791,14 @@ class StorageAllocator : public StorageAllocatorBaseVisitor { static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); if (call->op == alloc_tensor_op) { auto it = token_map_.find(call); - ICHECK(it != token_map_.end()); + TVM_FFI_ICHECK(it != token_map_.end()); if (it->second.IsNull()) { // IsNull being true means the token was discarded, and this alloc_tensor // is not considered by the planning. return; } - ICHECK(it->second.IsLeaf()); + TVM_FFI_ICHECK(it->second.IsLeaf()); StorageToken new_token = this->RequestReuseOrAlloc(it->second.LeafValue()); // Record that this alloc_tensor is using the token. @@ -803,7 +806,7 @@ class StorageAllocator : public StorageAllocatorBaseVisitor { token2cur_tensor_[new_token.get()].push_back(binding->var); SetTokens(call, Tokens(new_token)); // Record that the token is allocated in the current block. - ICHECK(!block_stack_.empty()); + TVM_FFI_ICHECK(!block_stack_.empty()); std::vector& block_tokens = block2tokens[block_stack_.back()]; if (std::find(block_tokens.begin(), block_tokens.end(), new_token.get()) == block_tokens.end()) { @@ -812,13 +815,13 @@ class StorageAllocator : public StorageAllocatorBaseVisitor { return; } else if (IsInplaceMemoryOp(call->op)) { Tokens tokens = GetTokens(call->args[0]); - ICHECK(!tokens.IsNested()); + TVM_FFI_ICHECK(!tokens.IsNested()); if (tokens.IsLeaf()) { // If the input is using a token, record that the reshape uses the token as well. token2cur_tensor_[tokens.LeafValue().get()].push_back(binding->var); SetTokens(call, tokens); } else { - ICHECK(token_map_[call].IsNull()); + TVM_FFI_ICHECK(token_map_[call].IsNull()); } return; } @@ -829,7 +832,7 @@ class StorageAllocator : public StorageAllocatorBaseVisitor { for (const Expr& arg : call->args) { Tokens tokens = GetTokens(arg); ForEachLeaf(tokens, [this](StorageToken token) { - ICHECK_GT(token->ref_counter, 0); + TVM_FFI_ICHECK_GT(token->ref_counter, 0); token->ref_counter -= 1; this->CheckForRelease(token); }); @@ -852,13 +855,13 @@ class StorageAllocator : public StorageAllocatorBaseVisitor { */ void CheckForRelease(StorageToken token) { // Sanity check: the token was allocated before and has non-negative reference. - ICHECK_GE(token->storage_id, 0); - ICHECK_GE(token->ref_counter, 0); + TVM_FFI_ICHECK_GE(token->storage_id, 0); + TVM_FFI_ICHECK_GE(token->ref_counter, 0); if (token->ref_counter == 0) { allocator_.Release(token); auto it = token2cur_tensor_.find(token.get()); - ICHECK(it != token2cur_tensor_.end()); + TVM_FFI_ICHECK(it != token2cur_tensor_.end()); token2cur_tensor_.erase(it); } } @@ -920,10 +923,10 @@ class StorageAllocationRewriter : public ExprMutator { auto it = alloc_tensor2token_.find(call); if (it != alloc_tensor2token_.end()) { // Case 1. This `alloc_tensor` is planned for memory reuse. - ICHECK_EQ(call->op, alloc_tensor_op); + TVM_FFI_ICHECK_EQ(call->op, alloc_tensor_op); const auto* sinfo = call->struct_info_.as(); - ICHECK_NOTNULL(sinfo); - ICHECK_NOTNULL(sinfo->shape.as()); + TVM_FFI_ICHECK_NOTNULL(sinfo); + TVM_FFI_ICHECK_NOTNULL(sinfo->shape.as()); PrimValue runtime_device_index = Downcast(call->args[2]); // If the token is visited for the first time, create a storage variable using @@ -959,13 +962,13 @@ class StorageAllocationRewriter : public ExprMutator { // allocate a tensor out from it with the actual symbolic shape. const auto* sinfo = call->struct_info_.as(); - ICHECK_NOTNULL(sinfo); + TVM_FFI_ICHECK_NOTNULL(sinfo); const auto* shape = sinfo->shape.as(); - ICHECK_NOTNULL(shape); + TVM_FFI_ICHECK_NOTNULL(shape); ffi::Array upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_, dom_map_); if (!IsStaticShape(shape->values)) { - ICHECK(!sinfo->IsUnknownDtype()); - ICHECK_EQ(sinfo->dtype, Downcast(call->args[1])->value); + TVM_FFI_ICHECK(!sinfo->IsUnknownDtype()); + TVM_FFI_ICHECK_EQ(sinfo->dtype, Downcast(call->args[1])->value); PrimExpr bytes = upper_bounded_shape[0]; for (int i = 1; i < static_cast(upper_bounded_shape.size()); ++i) { bytes *= upper_bounded_shape[i]; @@ -1044,7 +1047,7 @@ PrimExpr GetTextureMemorySizeFromVDevice(ffi::Array pshape, DataType d struct Shape { const ffi::Array& shape; int64_t operator[](size_t i) const { - ICHECK(tir::as_const_int(shape[i])) << "Dymamic shapes not suported over texture now"; + TVM_FFI_ICHECK(tir::as_const_int(shape[i])) << "Dymamic shapes not suported over texture now"; return *tir::as_const_int(shape[i]); } int size() { return this->shape.size(); } diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc index 66a148e593ca..5c371d618159 100644 --- a/src/relax/transform/to_mixed_precision.cc +++ b/src/relax/transform/to_mixed_precision.cc @@ -147,7 +147,7 @@ class DTypeDecisionCollector : public ExprVisitor { // merge the message for all vars in the expr list void RequireArgsToType(ffi::Array args, ffi::Array to) { - ICHECK(args.size() == to.size()) << "Invalid target dtypes"; + TVM_FFI_ICHECK(args.size() == to.size()) << "Invalid target dtypes"; for (size_t i = 0; i < args.size(); ++i) { auto fvisitleaf = [&](const Expr& expr, NType to) { if (const auto* var = expr.as()) { @@ -156,7 +156,7 @@ class DTypeDecisionCollector : public ExprVisitor { // Constant can be casted anyway, so we don't need to do anything here return; } else { - LOG(FATAL) << "Unsupported argument type: " << expr->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unsupported argument type: " << expr->GetTypeKey(); } }; DecomposeNestedMsg(args[i], to[i], fvisitleaf); @@ -202,7 +202,7 @@ class DTypeDecisionCollector : public ExprVisitor { // require inputs to be fp32 (the original dtype) RequireArgsToType(call_node->args, fp32_); } else { - LOG(FATAL) << "Unsupported TMixedPrecisionPolicy: " << policy; + TVM_FFI_THROW(InternalError) << "Unsupported TMixedPrecisionPolicy: " << policy; } } @@ -219,7 +219,7 @@ class DTypeDecisionCollector : public ExprVisitor { std::vector require_rhs; const TupleStructInfoNode* sinfo = tuple_get_item_node->tuple->struct_info_.as(); - ICHECK(sinfo != nullptr) << "TupleGetItemNode must have TupleStructInfo"; + TVM_FFI_ICHECK(sinfo != nullptr) << "TupleGetItemNode must have TupleStructInfo"; for (size_t i = 0; i < sinfo->fields.size(); ++i) { if (i == static_cast(tuple_get_item_node->index)) { require_rhs.push_back(lhs_type); @@ -311,7 +311,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { Expr RewriteExpr(const Expr& expr, const NType& to) { auto fvisitleaf = [&](const Expr& expr, std::array to) -> Expr { const auto* tensor = GetStructInfoAs(expr); - ICHECK(tensor != nullptr) << "Only support rewriting tensor expr"; + TVM_FFI_ICHECK(tensor != nullptr) << "Only support rewriting tensor expr"; // We only rewrite the expr if the dtype is not the same as the given dtype if (NTypeEqual()(to[0], NTypeFrom(expr))) return expr; // We only rewrite the expr if the dtype is fp16 or fp32, dtypes such as int32, float64 is not @@ -406,7 +406,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { } void CastIfFp16Only(const Var& var) { - ICHECK(builder_->CurrentBlockIsDataFlow()); + TVM_FFI_ICHECK(builder_->CurrentBlockIsDataFlow()); // Get the current remapped var Var cur_var = GetRemapped(var); // Store the tensors that are fp16 only to fp16 @@ -463,7 +463,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { } // var = Call(op) const auto* op_node = call_node->op.as(); - ICHECK(op_node != nullptr); + TVM_FFI_ICHECK(op_node != nullptr); Op op = ffi::GetRef(op_node); if (wrap_param_op.same_as(op)) { // wrap_param @@ -482,7 +482,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { if (policy == kAlways) { opt_new_dtype = fp16_; auto attr_map = Op::GetAttrMap("FInferMixedPrecision"); - ICHECK(attr_map.count(op)); + TVM_FFI_ICHECK(attr_map.count(op)); new_call = attr_map[op](new_call, output_dtype_); } else if (policy == kFollow) { opt_new_dtype = AllFP16Castable(new_call->args) ? fp16_ : fp32_; @@ -503,7 +503,7 @@ class ToMixedPrecisionRewriter : public ExprMutator { } } else { - LOG(FATAL) << "Unsupported TMixedPrecisionPolicy: " << policy; + TVM_FFI_THROW(InternalError) << "Unsupported TMixedPrecisionPolicy: " << policy; } Expr new_value = new_call; diff --git a/src/relax/transform/topological_sort.cc b/src/relax/transform/topological_sort.cc index 114af668b980..74b65f6648e9 100644 --- a/src/relax/transform/topological_sort.cc +++ b/src/relax/transform/topological_sort.cc @@ -207,7 +207,7 @@ class TopologicalSorter : public ExprMutator { case StartingLocation::FromOutputs: return dependencies_.upstream_requirements; default: - LOG(FATAL) << "Invalid enum value for StartingLocation"; + TVM_FFI_THROW(InternalError) << "Invalid enum value for StartingLocation"; } }(); @@ -226,7 +226,7 @@ class TopologicalSorter : public ExprMutator { case StartingLocation::FromOutputs: return dependencies_.downstream_users; default: - LOG(FATAL) << "Invalid enum value for StartingLocation"; + TVM_FFI_THROW(InternalError) << "Invalid enum value for StartingLocation"; } }(); @@ -242,7 +242,7 @@ class TopologicalSorter : public ExprMutator { case StartingLocation::FromOutputs: return {OutputNode()}; default: - LOG(FATAL) << "Invalid enum value for StartingLocation"; + TVM_FFI_THROW(InternalError) << "Invalid enum value for StartingLocation"; } }(); @@ -264,7 +264,7 @@ class TopologicalSorter : public ExprMutator { } auto it = backward_edge_lookup.find(adjacent_var); - ICHECK(it != backward_edge_lookup.end()); + TVM_FFI_ICHECK(it != backward_edge_lookup.end()); const auto& prerequisites = it->second; return std::all_of(prerequisites.begin(), prerequisites.end(), [&visited](const auto& var) { return visited.count(var); }); @@ -291,7 +291,8 @@ class TopologicalSorter : public ExprMutator { break; } default: { - LOG(FATAL) << "Invalid value for TraversalOrder: " << static_cast(order_); + TVM_FFI_THROW(InternalError) + << "Invalid value for TraversalOrder: " << static_cast(order_); } } @@ -305,18 +306,20 @@ class TopologicalSorter : public ExprMutator { push_descendents_to_stack(visiting); } - ICHECK_EQ(to_emit.size(), 0) << "After visiting all bindings, " - << "no bindings should remain to emit. " - << "However, bindings " << + TVM_FFI_ICHECK_EQ(to_emit.size(), 0) + << "After visiting all bindings, " + << "no bindings should remain to emit. " + << "However, bindings " << [&]() { ffi::Array arr; for (const auto& [var, binding] : to_emit) { arr.push_back(var); } return arr; - }() << " still remain after emitting " - << ffi::Array(new_bindings.begin(), new_bindings.end()) - .Map([](const Binding& binding) { return binding->var; }); + }() + << " still remain after emitting " + << ffi::Array(new_bindings.begin(), new_bindings.end()) + .Map([](const Binding& binding) { return binding->var; }); if (starting_location_ == StartingLocation::FromOutputs) { std::reverse(new_bindings.begin(), new_bindings.end()); @@ -345,35 +348,35 @@ Pass TopologicalSort(TraversalOrder order, StartingLocation starting_location) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "relax.transform.TopologicalSort", - [](ffi::String order_str, ffi::String direction_str) -> Pass { - TraversalOrder order = [&]() { - if (order_str == "depth-first") { - return TraversalOrder::DepthFirst; - } else if (order_str == "breadth-first") { - return TraversalOrder::BreadthFirst; - } else { - LOG(FATAL) << "ValueError: " - << "Invalid value for traversal order: \"" << order_str << "\". " - << "Allowed values are \"depth-first\" or \"breadth-first\""; - } - }(); - - StartingLocation starting_location = [&]() { - if (direction_str == "from-inputs") { - return StartingLocation::FromInputs; - } else if (direction_str == "from-outputs") { - return StartingLocation::FromOutputs; - } else { - LOG(FATAL) << "ValueError: " - << "Invalid value for starting location: \"" << direction_str << "\". " - << "Allowed values are \"from-inputs\" or \"from-outputs\""; - } - }(); - - return TopologicalSort(order, starting_location); - }); + refl::GlobalDef().def("relax.transform.TopologicalSort", + [](ffi::String order_str, ffi::String direction_str) -> Pass { + TraversalOrder order = [&]() { + if (order_str == "depth-first") { + return TraversalOrder::DepthFirst; + } else if (order_str == "breadth-first") { + return TraversalOrder::BreadthFirst; + } else { + TVM_FFI_THROW(ValueError) + << "Invalid value for traversal order: \"" << order_str << "\". " + << "Allowed values are \"depth-first\" or \"breadth-first\""; + } + }(); + + StartingLocation starting_location = [&]() { + if (direction_str == "from-inputs") { + return StartingLocation::FromInputs; + } else if (direction_str == "from-outputs") { + return StartingLocation::FromOutputs; + } else { + TVM_FFI_THROW(ValueError) + << "Invalid value for starting location: \"" << direction_str + << "\". " + << "Allowed values are \"from-inputs\" or \"from-outputs\""; + } + }(); + + return TopologicalSort(order, starting_location); + }); } } // namespace transform diff --git a/src/relax/transform/utils.cc b/src/relax/transform/utils.cc index 580b3892e57b..32f080cb5124 100644 --- a/src/relax/transform/utils.cc +++ b/src/relax/transform/utils.cc @@ -69,8 +69,7 @@ Function ComposeFunctions(Function func_a, Function func_b) { auto param = func_b->params[0]; bindings.push_back(MatchCast(param, func_a_output, GetStructInfo(param))); } else { - CHECK_EQ(func_a_outputs.size(), func_b->params.size()) - << "ValueError: " + TVM_FFI_CHECK_EQ(func_a_outputs.size(), func_b->params.size(), ValueError) << "Cannot compose functions together. " << "First function produces " << func_a_outputs.size() << " values, " << "but second function expects " << func_b->params.size() << " parameters as input"; diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 2cce918f4ef7..5d216f3f8425 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -74,7 +74,7 @@ class MemoizedExprTranslator : public ::tvm::relax::ExprFunctorsecond; @@ -85,12 +85,12 @@ class MemoizedExprTranslator : public ::tvm::relax::ExprFunctor(vn))); + TVM_FFI_ICHECK(memo_.count(ffi::GetRef(vn))); return memo_[ffi::GetRef(vn)]; } virtual OutputType VisitBinding_(const VarBindingNode* binding) { - ICHECK_EQ(memo_.count(binding->var), 0); + TVM_FFI_ICHECK_EQ(memo_.count(binding->var), 0); auto v = VisitExpr(binding->value); memo_[binding->var] = v; return v; @@ -126,7 +126,7 @@ TVM_DLL IRModule DeadCodeElimination(const IRModule& mod, ffi::ArrayGetAttr(tvm::attr::kGlobalSymbol); - ICHECK(name_node.has_value()) << "Fail to retrieve external symbol."; + TVM_FFI_ICHECK(name_node.has_value()) << "Fail to retrieve external symbol."; return std::string(name_node.value()); } @@ -356,7 +356,7 @@ inline Constant MakeConstantScalar(T value, DataType dtype) { *static_cast(arr->data) = __truncXfYf2__(static_cast(value)); } else { - LOG(FATAL) << "Unsupported dtype " << dtype; + TVM_FFI_THROW(InternalError) << "Unsupported dtype " << dtype; } return Constant(arr); } @@ -369,8 +369,9 @@ inline ffi::Array GetOrderedPositiveAxes(const ffi::Array& axe if (axis_val < 0) { axis_val += ndim; } - ICHECK(axis_val >= 0 && axis_val < ndim) << "axis " << axis << " is out of bounds for array of " - << "dimension " << ndim; + TVM_FFI_ICHECK(axis_val >= 0 && axis_val < ndim) + << "axis " << axis << " is out of bounds for array of " + << "dimension " << ndim; ret.push_back(axis_val); } std::sort(ret.begin(), ret.end()); @@ -379,8 +380,9 @@ inline ffi::Array GetOrderedPositiveAxes(const ffi::Array& axe inline ffi::String GetCodegenName(const std::string& composite_name) { auto delim_pos = composite_name.find("."); - ICHECK(delim_pos != std::string::npos) << "The pattern name for a composite function should " - "start with a compiler name followed by period."; + TVM_FFI_ICHECK(delim_pos != std::string::npos) + << "The pattern name for a composite function should " + "start with a compiler name followed by period."; return composite_name.substr(0, delim_pos); } @@ -404,7 +406,7 @@ inline int GetDeviceIndex(const IRModule& mod, const VDevice& vdevice) { return i; } } - LOG(FATAL) << "The vdevice is not in the ir_module."; + TVM_FFI_THROW(InternalError) << "The vdevice is not in the ir_module."; return -1; } diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 37e53a614ff0..ae4c953fd007 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -221,7 +221,8 @@ bool IsImpureCall(const Call& call) { if (auto op_ptr = call->op.as()) { auto op = ffi::GetRef(op_ptr); static auto purity_map = Op::GetAttrMap("FPurity"); - ICHECK(purity_map.count(op)) << "Cannot find the registered purity of this op: " << op->name; + TVM_FFI_ICHECK(purity_map.count(op)) + << "Cannot find the registered purity of this op: " << op->name; return !(purity_map[op]->value); } // the StructInfo must be FuncStructInfo @@ -235,7 +236,7 @@ Expr GetBoundValue(const Binding& b) { } else if (auto* match_binding = b.as()) { return match_binding->value; } else { - CHECK(false) << "Invalid binding (should never happen)"; + TVM_FFI_ICHECK(false) << "Invalid binding (should never happen)"; } } diff --git a/src/runtime/const_loader_module.cc b/src/runtime/const_loader_module.cc index f372715d585e..8a125222a028 100644 --- a/src/runtime/const_loader_module.cc +++ b/src/runtime/const_loader_module.cc @@ -61,7 +61,7 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { for (const auto& var : kv.second) { VLOG(1) << "ConstLoaderModuleNode has constant '" << var << "' for function '" << kv.first << "'"; - ICHECK_GT(const_var_tensor_.count(var), 0) + TVM_FFI_ICHECK_GT(const_var_tensor_.count(var), 0) << "ConstLoaderModuleNode is missing entry for constant '" << var << "' for function '" << kv.first << "'"; } @@ -93,7 +93,7 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { // Run the module. // Normally we would only have a limited number of submodules. The runtime // symobl lookup overhead should be minimal. - ICHECK(!this->imports_.empty()); + TVM_FFI_ICHECK(!this->imports_.empty()); for (const Any& it : this->imports_) { ffi::Optional pf = it.cast()->GetFunction(name); if (pf.has_value()) return pf.value(); @@ -113,11 +113,11 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { */ ffi::Array GetRequiredConstants(const std::string& symbol) { ffi::Array ret; - ICHECK_GT(const_vars_by_symbol_.count(symbol), 0U) + TVM_FFI_ICHECK_GT(const_vars_by_symbol_.count(symbol), 0U) << "No constants known for function '" << symbol << "'"; std::vector vars = const_vars_by_symbol_[symbol]; for (const auto& var : vars) { - ICHECK_GT(const_var_tensor_.count(var), 0U) + TVM_FFI_ICHECK_GT(const_var_tensor_.count(var), 0U) << "No such constant variable '" << var << "' for function '" << symbol << "'"; ret.push_back(const_var_tensor_[var]); } @@ -147,7 +147,7 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { // Initialize the module with constants. int ret = (*init)(md).cast(); // Report the error if initialization is failed. - ICHECK_EQ(ret, 0); + TVM_FFI_ICHECK_EQ(ret, 0); break; } } @@ -196,10 +196,10 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { // Load the variables. std::vector variables; - ICHECK(stream.Read(&variables)) << "Loading variable names failed"; + TVM_FFI_ICHECK(stream.Read(&variables)) << "Loading variable names failed"; uint64_t sz; - ICHECK(stream.Read(&sz, sizeof(sz))) << "Loading number of vars failed"; - ICHECK_EQ(static_cast(sz), variables.size()) + TVM_FFI_ICHECK(stream.Read(&sz, sizeof(sz))) << "Loading number of vars failed"; + TVM_FFI_ICHECK_EQ(static_cast(sz), variables.size()) << "The number of variables and ndarray counts must match"; // Load the list of ndarray. std::vector arrays; @@ -211,19 +211,19 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { std::unordered_map const_var_tensor; for (uint64_t i = 0; i < sz; i++) { - ICHECK_EQ(const_var_tensor.count(variables[i]), 0U); + TVM_FFI_ICHECK_EQ(const_var_tensor.count(variables[i]), 0U); const_var_tensor[variables[i]] = arrays[i]; } // Load the symbol to list of required constant variables mapping std::vector symbols; - ICHECK(stream.Read(&symbols)) << "Loading symbols failed"; - ICHECK(stream.Read(&sz, sizeof(sz))) << "Loading number of symbols failed"; - ICHECK_EQ(static_cast(sz), symbols.size()); + TVM_FFI_ICHECK(stream.Read(&symbols)) << "Loading symbols failed"; + TVM_FFI_ICHECK(stream.Read(&sz, sizeof(sz))) << "Loading number of symbols failed"; + TVM_FFI_ICHECK_EQ(static_cast(sz), symbols.size()); std::vector> const_vars; for (uint64_t i = 0; i < sz; i++) { std::vector vars; - ICHECK(stream.Read(&vars)) << "Loading const variables failed"; + TVM_FFI_ICHECK(stream.Read(&vars)) << "Loading const variables failed"; const_vars.push_back(vars); } diff --git a/src/runtime/contrib/amx/amx_config.cc b/src/runtime/contrib/amx/amx_config.cc index 4be9d57811b3..9671994ce0e8 100644 --- a/src/runtime/contrib/amx/amx_config.cc +++ b/src/runtime/contrib/amx/amx_config.cc @@ -100,9 +100,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { int64_t status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); if (0 != status) { *rv = 0; - LOG(FATAL) << "errno:" << errno << ", " << strerror(errno); - LOG(FATAL) << "status[0]: " << status << ", bitmask: " << bitmask - << ", XFEATURE_XTILEDATA setup is failed, TMUL feature is not allowed."; + TVM_FFI_THROW(InternalError) << "errno:" << errno << ", " << strerror(errno); + TVM_FFI_THROW(InternalError) + << "status[0]: " << status << ", bitmask: " << bitmask + << ", XFEATURE_XTILEDATA setup is failed, TMUL feature is not allowed."; return; } if (bitmask & XFEATURE_MASK_XTILEDATA) { @@ -114,9 +115,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { // if XFEATURE_XTILEDATA setup is failed, TMUL usage is not allowed if (0 != status) { *rv = 0; - LOG(FATAL) << "errno:" << errno << ", " << strerror(errno); - LOG(FATAL) << "status[1]: " << status << ", bitmask: " << bitmask - << ", XFEATURE_XTILEDATA setup is failed, TMUL usage is not allowed."; + TVM_FFI_THROW(InternalError) << "errno:" << errno << ", " << strerror(errno); + TVM_FFI_THROW(InternalError) + << "status[1]: " << status << ", bitmask: " << bitmask + << ", XFEATURE_XTILEDATA setup is failed, TMUL usage is not allowed."; return; } @@ -124,9 +126,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { // if XFEATURE_XTILEDATA setup is failed, can't use TMUL if (0 != status || !(bitmask & XFEATURE_MASK_XTILEDATA)) { *rv = 0; - LOG(FATAL) << "errno:" << errno << ", " << strerror(errno); - LOG(FATAL) << "status[2]: " << status << ", bitmask: " << bitmask - << ", XFEATURE_XTILEDATA setup is failed, can't use TMUL."; + TVM_FFI_THROW(InternalError) << "errno:" << errno << ", " << strerror(errno); + TVM_FFI_THROW(InternalError) << "status[2]: " << status << ", bitmask: " << bitmask + << ", XFEATURE_XTILEDATA setup is failed, can't use TMUL."; return; } diff --git a/src/runtime/contrib/arm_compute_lib/acl_allocator.cc b/src/runtime/contrib/arm_compute_lib/acl_allocator.cc index b843841f5755..580fa7438398 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_allocator.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_allocator.cc @@ -29,7 +29,7 @@ namespace runtime { namespace contrib { void* ACLAllocator::allocate(size_t size, size_t alignment) { - ICHECK_GT(size, 0) << "Cannot allocate size less than or equal to zero"; + TVM_FFI_ICHECK_GT(size, 0) << "Cannot allocate size less than or equal to zero"; return this->device_api_->AllocWorkspace(this->device_, size, {}); } diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index d532f695fc70..819670378fca 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -78,7 +78,7 @@ class ACLRuntime : public JSONRuntimeBase { * \param consts The constant params from compiled model. */ void Init(const ffi::Array& consts) override { - ICHECK_EQ(consts.size(), const_idx_.size()) + TVM_FFI_ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; SetupConstants(consts); BuildEngine(); @@ -134,7 +134,7 @@ class ACLRuntime : public JSONRuntimeBase { for (size_t nid = 0; nid < nodes_.size(); ++nid) { const auto& node = nodes_[nid]; if (found_kernel_node) { - LOG(FATAL) + TVM_FFI_THROW(InternalError) << "Arm Compute Library runtime module only supports one kernel node per function."; } if (node.GetOpType() == "kernel") { @@ -163,7 +163,7 @@ class ACLRuntime : public JSONRuntimeBase { } else if ("concatenate" == op_name) { CreateConcatenateLayer(&layer_, node); } else { - LOG(FATAL) << "Unsupported op: " << op_name; + TVM_FFI_THROW(InternalError) << "Unsupported op: " << op_name; } } } @@ -257,7 +257,8 @@ class ACLRuntime : public JSONRuntimeBase { arm_compute::PadStrideInfo pad_stride_info = MakeACLPadStride(padding, strides); int groups = static_cast(node.GetAttr("groups")); - ICHECK(groups == 1) << "Arm Compute Library NEON convolution only supports group size of 1."; + TVM_FFI_ICHECK(groups == 1) + << "Arm Compute Library NEON convolution only supports group size of 1."; arm_compute::ActivationLayerInfo act_info; if (node.HasAttr("activation_type")) { @@ -273,7 +274,7 @@ class ACLRuntime : public JSONRuntimeBase { size_t num_inputs = inputs.size(); bool has_bias; if (node.GetOpName() == "qnn.conv2d") { - ICHECK(num_inputs >= 8U && num_inputs <= 9U) + TVM_FFI_ICHECK(num_inputs >= 8U && num_inputs <= 9U) << "Quantized convolution requires 9 inputs with a bias, 8 inputs without."; has_bias = num_inputs == 9; layer->inputs.push_back(MakeACLTensorFromJSONEntry(inputs[0], &inputs[4], &inputs[2])); @@ -284,7 +285,7 @@ class ACLRuntime : public JSONRuntimeBase { layer->outputs.push_back( MakeACLTensorFromJSONNode(node, &inputs[6 + has_bias], &inputs[7 + has_bias])); } else { - ICHECK(num_inputs >= 2U && num_inputs <= 3U) + TVM_FFI_ICHECK(num_inputs >= 2U && num_inputs <= 3U) << "Convolution requires 3 inputs with a bias, 2 inputs without."; has_bias = num_inputs == 3; for (const auto& i : inputs) { @@ -329,7 +330,7 @@ class ACLRuntime : public JSONRuntimeBase { size_t num_inputs = inputs.size(); bool has_bias; if (node.GetOpName() == "qnn.depthwise_conv2d") { - ICHECK(num_inputs >= 8U && num_inputs <= 9U) + TVM_FFI_ICHECK(num_inputs >= 8U && num_inputs <= 9U) << "Quantized convolution requires 9 inputs with a bias, 8 inputs without."; has_bias = num_inputs == 9; layer->inputs.push_back(MakeACLTensorFromJSONEntry(inputs[0], &inputs[4], &inputs[2])); @@ -340,7 +341,7 @@ class ACLRuntime : public JSONRuntimeBase { layer->outputs.push_back( MakeACLTensorFromJSONNode(node, &inputs[6 + has_bias], &inputs[7 + has_bias])); } else { - ICHECK(num_inputs >= 2U && num_inputs <= 3U) + TVM_FFI_ICHECK(num_inputs >= 2U && num_inputs <= 3U) << "Convolution requires 3 inputs with a bias, 2 inputs without."; has_bias = num_inputs == 3; for (const auto& i : inputs) { @@ -376,7 +377,7 @@ class ACLRuntime : public JSONRuntimeBase { size_t num_inputs = inputs.size(); bool has_bias; if (node.GetOpName() == "qnn.dense") { - ICHECK(num_inputs >= 8U && num_inputs <= 9U) + TVM_FFI_ICHECK(num_inputs >= 8U && num_inputs <= 9U) << "Quantized fully connected (dense) layer requires 9 inputs with a bias, 8 inputs " "without."; has_bias = num_inputs == 9; @@ -388,7 +389,7 @@ class ACLRuntime : public JSONRuntimeBase { layer->outputs.push_back( MakeACLTensorFromJSONNode(node, &inputs[6 + has_bias], &inputs[7 + has_bias])); } else { - ICHECK(num_inputs >= 2U && num_inputs <= 3U) + TVM_FFI_ICHECK(num_inputs >= 2U && num_inputs <= 3U) << "Fully connected (dense) layer requires 3 inputs with a bias, 2 inputs without."; has_bias = num_inputs == 3; for (const auto& i : inputs) { @@ -437,10 +438,10 @@ class ACLRuntime : public JSONRuntimeBase { } else if (node.GetOpName() == "nn.l2_pool2d") { pool_type = arm_compute::PoolingType::L2; } else { - LOG(FATAL) << "Pooling type not supported"; + TVM_FFI_THROW(InternalError) << "Pooling type not supported"; } - ICHECK(dilation.size() == 2 && dilation[0] == 1 && dilation[1] == 1) + TVM_FFI_ICHECK(dilation.size() == 2 && dilation[0] == 1 && dilation[1] == 1) << "Dilation other than (1, 1) not supported"; arm_compute::PoolingLayerInfo pool_info = arm_compute::PoolingLayerInfo(pool_type, arm_compute::Size2D(pool_size_h, pool_size_w), @@ -469,7 +470,7 @@ class ACLRuntime : public JSONRuntimeBase { } else if (node.GetOpName() == "nn.global_avg_pool2d") { pool_type = arm_compute::PoolingType::AVG; } else { - LOG(FATAL) << "Pooling type not supported"; + TVM_FFI_THROW(InternalError) << "Pooling type not supported"; } arm_compute::PoolingLayerInfo pool_info = @@ -531,7 +532,7 @@ class ACLRuntime : public JSONRuntimeBase { layer->outputs.push_back( MakeACLTensorFromJSONNode(node, &node.GetInputs()[6], &node.GetInputs()[7])); } else { - LOG(FATAL) << "Unsupported form of add op: " + op_name; + TVM_FFI_THROW(InternalError) << "Unsupported form of add op: " + op_name; } auto f = std::make_shared(); @@ -579,8 +580,9 @@ class ACLRuntime : public JSONRuntimeBase { CachedLayer layer_; #else void Run() override { - LOG(FATAL) << "Cannot call run on Arm Compute Library module without runtime enabled. " - << "Please build with USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR."; + TVM_FFI_THROW(InternalError) + << "Cannot call run on Arm Compute Library module without runtime enabled. " + << "Please build with USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR."; } void BuildEngine() { diff --git a/src/runtime/contrib/arm_compute_lib/acl_utils.cc b/src/runtime/contrib/arm_compute_lib/acl_utils.cc index d8fcaf9d971b..323a8bcdc658 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_utils.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_utils.cc @@ -35,7 +35,7 @@ namespace contrib { using JSONGraphNode = tvm::runtime::json::JSONGraphNode; void CheckACLError(const arm_compute::Status& status) { - ICHECK(status.error_code() == arm_compute::ErrorCode::OK) + TVM_FFI_ICHECK(status.error_code() == arm_compute::ErrorCode::OK) << "ACL: " << status.error_description(); } @@ -72,7 +72,7 @@ arm_compute::TensorInfo MakeACLTensorInfo(const std::vector& shape, if (scale != nullptr && offset != nullptr) { std::vector scale_data = GetVectorFromDLTensor(scale); std::vector offset_data = GetVectorFromDLTensor(offset); - ICHECK(scale_data.size() == 1 && offset_data.size() == 1) + TVM_FFI_ICHECK(scale_data.size() == 1 && offset_data.size() == 1) << "Currently only per-layer quantization is supported in the Arm Compute Library runtime."; arm_compute::QuantizationInfo qinfo(scale_data[0], offset_data[0]); info.set_quantization_info(qinfo); @@ -114,7 +114,7 @@ arm_compute::PadStrideInfo MakeACLPadStride(const ffi::Array& pad, pad_2 = static_cast(pad[0]); pad_3 = static_cast(pad[2]); } else { - LOG(FATAL) << "Unsupported padding dimensions"; + TVM_FFI_THROW(InternalError) << "Unsupported padding dimensions"; } if (ceil_mode) { @@ -135,7 +135,7 @@ arm_compute::DataType MakeACLDataType(const DLDataType& data_type) { } else if (data_type.code == DLDataTypeCode::kDLInt && data_type.bits == 32) { return arm_compute::DataType::S32; } else { - LOG(FATAL) << "Datatype " << data_type << " unsupported by ACL runtime"; + TVM_FFI_THROW(InternalError) << "Datatype " << data_type << " unsupported by ACL runtime"; } } @@ -144,14 +144,15 @@ arm_compute::ActivationLayerInfo MakeACLActivationInfo(const std::string& activa if (activation_type == "relu") { act_func = arm_compute::ActivationLayerInfo::ActivationFunction::RELU; } else { - LOG(FATAL) << "Activation " << activation_type << " unsupported by ACL runtime"; + TVM_FFI_THROW(InternalError) << "Activation " << activation_type + << " unsupported by ACL runtime"; } return {act_func}; } template std::vector GetVectorFromDLTensor(const DLTensor* tensor) { - ICHECK(tensor) << "Cannot convert a nullptr"; + TVM_FFI_ICHECK(tensor) << "Cannot convert a nullptr"; int len = 1; for (int i = 0; i < tensor->ndim; i++) { len *= tensor->shape[i]; diff --git a/src/runtime/contrib/bnns/bnns_json_runtime.cc b/src/runtime/contrib/bnns/bnns_json_runtime.cc index 0364e1ea57cc..125fa6ae8ce7 100644 --- a/src/runtime/contrib/bnns/bnns_json_runtime.cc +++ b/src/runtime/contrib/bnns/bnns_json_runtime.cc @@ -94,7 +94,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { const char* kind() const override { return "bnns_json"; } void Init(const ffi::Array& consts) override { - ICHECK_EQ(consts.size(), const_idx_.size()) + TVM_FFI_ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; SetupConstants(consts); @@ -163,7 +163,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { for (size_t nid = 0; nid < nodes_.size(); ++nid) { const auto& node = nodes_[nid]; if (node.GetOpType() == "kernel") { - ICHECK_EQ(node.GetOpType(), "kernel"); + TVM_FFI_ICHECK_EQ(node.GetOpType(), "kernel"); auto op_name = node.GetOpName(); if ("nn.conv2d" == op_name) { Conv2d(nid); @@ -196,7 +196,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { } else if ("nn.global_avg_pool2d" == op_name) { Pooling(nid, true, true); } else { - LOG(FATAL) << "Unsupported op: " << op_name; + TVM_FFI_THROW(InternalError) << "Unsupported op: " << op_name; } } } @@ -205,7 +205,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { // Get BNNS tensor. std::shared_ptr GetBNNSTensor(const JSONGraphNodeEntry& entry) { auto eid = EntryID(entry); - ICHECK(eid < tensors_eid_.size()); + TVM_FFI_ICHECK(eid < tensors_eid_.size()); return tensors_eid_[eid]; } @@ -284,7 +284,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { for (int i = 0; i < params.size(); i++) { auto common_filter_param = getCommonFilterParams(); filters[i] = BNNSFilterCreateLayerConvolution(¶ms[i], &common_filter_param); - ICHECK(filters[i]) << "BNNS primitive was not created. Unsupported attributes configuration"; + TVM_FFI_ICHECK(filters[i]) + << "BNNS primitive was not created. Unsupported attributes configuration"; } primitives_.emplace_back(std::make_shared(filters, src_view, dst_view)); @@ -331,7 +332,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { auto common_filter_param = getCommonFilterParams(); auto filter = BNNSFilterCreateLayerFullyConnected(&layerParameters, &common_filter_param); - ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration"; + TVM_FFI_ICHECK(filter) + << "BNNS primitive was not created. Unsupported attributes configuration"; std::vector filters = {filter}; primitives_.emplace_back(std::make_shared(filters, src_view, dst_view)); } @@ -374,7 +376,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { auto common_filter_param = getCommonFilterParams(); auto filter = BNNSFilterCreateLayerBroadcastMatMul(&layerParameters, &common_filter_param); - ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration"; + TVM_FFI_ICHECK(filter) + << "BNNS primitive was not created. Unsupported attributes configuration"; std::vector filters{filter}; if (a_is_weighted || b_is_weighted) { @@ -409,8 +412,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { auto dst_view = TView::as_is(dst_t); size_t src_rank = Tensor::getRank(src_view.get_bnns_view()); size_t dst_rank = Tensor::getRank(dst_view.get_bnns_view()); - ICHECK_EQ(src_rank, dst_rank); - ICHECK_LE(src_rank, 4); + TVM_FFI_ICHECK_EQ(src_rank, dst_rank); + TVM_FFI_ICHECK_LE(src_rank, 4); if (src_rank < 4) { src_view = src_view.unsqueeze(4); dst_view = dst_view.unsqueeze(4); @@ -443,7 +446,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { auto common_filter_param = getCommonFilterParams(); auto filter = BNNSFilterCreateLayerNormalization(filter_type, &layerParameters, &common_filter_param); - ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration"; + TVM_FFI_ICHECK(filter) + << "BNNS primitive was not created. Unsupported attributes configuration"; std::vector filters{filter}; primitives_.emplace_back(std::make_shared(filters, src_view, dst_view)); @@ -463,8 +467,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { auto dst_view = TView::as_is(dst_t); size_t src_rank = Tensor::getRank(src_view.get_bnns_view()); size_t dst_rank = Tensor::getRank(dst_view.get_bnns_view()); - ICHECK_EQ(src_rank, dst_rank); - ICHECK_LE(src_rank, 4); + TVM_FFI_ICHECK_EQ(src_rank, dst_rank); + TVM_FFI_ICHECK_LE(src_rank, 4); if (src_rank < 4) { src_view = src_view.unsqueeze(4); dst_view = dst_view.unsqueeze(4); @@ -515,7 +519,8 @@ class BNNSJSONRuntime : public JSONRuntimeBase { auto common_filter_param = getCommonFilterParams(); auto filter = BNNSFilterCreateLayerPooling(&layerParameters, &common_filter_param); - ICHECK(filter) << "BNNS primitive was not created. Unsupported attributes configuration"; + TVM_FFI_ICHECK(filter) + << "BNNS primitive was not created. Unsupported attributes configuration"; std::vector filters{filter}; primitives_.emplace_back(std::make_shared(filters, src_view, dst_view)); @@ -536,7 +541,7 @@ class BNNSJSONRuntime : public JSONRuntimeBase { if (dl_dtype.bits == 16) return BNNSDataTypeUInt16; if (dl_dtype.bits == 8) return BNNSDataTypeUInt8; } - LOG(FATAL) << "Unsupported data type for BNNS runtime"; + TVM_FFI_THROW(InternalError) << "Unsupported data type for BNNS runtime"; } BNNSFilterParameters getCommonFilterParams() { diff --git a/src/runtime/contrib/bnns/bnns_wrp.h b/src/runtime/contrib/bnns/bnns_wrp.h index 1997e0a84d71..bb729f1efde1 100644 --- a/src/runtime/contrib/bnns/bnns_wrp.h +++ b/src/runtime/contrib/bnns/bnns_wrp.h @@ -60,7 +60,7 @@ class Tensor { Tensor(Shape shape, Dtype dtype, void* hdl) { auto rank = shape.size(); - ICHECK(rank < BNNS_MAX_TENSOR_DIMENSION); + TVM_FFI_ICHECK(rank < BNNS_MAX_TENSOR_DIMENSION); desc_ = {BNNSTensorFlags(0), getPlainLayout(rank), @@ -91,7 +91,7 @@ class Tensor { } const size_t buff_size = getSize(desc_) * getElementSize(desc_); desc_.data = default_alloc(buff_size); - ICHECK(desc_.data); + TVM_FFI_ICHECK(desc_.data); is_external_data = false; } @@ -110,7 +110,7 @@ class Tensor { const BNNSTensorDescriptor& get_desc() const { return desc_; } static BNNSDataLayout getPlainLayout(size_t rank) { - ICHECK(rank <= BNNS_MAX_TENSOR_DIMENSION); + TVM_FFI_ICHECK(rank <= BNNS_MAX_TENSOR_DIMENSION); return static_cast((rank << 16) | 0x8001); } @@ -201,9 +201,9 @@ class TView { TView res = *this; size_t unsqueezed_shape[BNNS_MAX_TENSOR_DIMENSION] = {}; size_t unsqueezed_rank = axes.size() + rank; - ICHECK_LE(unsqueezed_rank, BNNS_MAX_TENSOR_DIMENSION); + TVM_FFI_ICHECK_LE(unsqueezed_rank, BNNS_MAX_TENSOR_DIMENSION); for (const auto& axis : axes) { - ICHECK_LT(axis, unsqueezed_rank); + TVM_FFI_ICHECK_LT(axis, unsqueezed_rank); unsqueezed_shape[axis] = 1; } for (int i = 0, orig_idx = 0; i < unsqueezed_rank; ++i) { @@ -217,9 +217,9 @@ class TView { /** Unsqueeze tensor to a new rank */ TView unsqueeze(size_t new_rank) const { - ICHECK_LE(new_rank, BNNS_MAX_TENSOR_DIMENSION); + TVM_FFI_ICHECK_LE(new_rank, BNNS_MAX_TENSOR_DIMENSION); auto rank = Tensor::getRank(view_desc_); - ICHECK_GT(new_rank, rank); + TVM_FFI_ICHECK_GT(new_rank, rank); std::vector axes(new_rank - rank); std::iota(axes.begin(), axes.end(), rank); return expand_dims(axes); @@ -227,7 +227,7 @@ class TView { /** Construct new TView with specified layout if it applicable */ TView with_layout(BNNSDataLayout layout) const { - ICHECK_EQ(Tensor::getRank(view_desc_), Tensor::getRank(layout)); + TVM_FFI_ICHECK_EQ(Tensor::getRank(view_desc_), Tensor::getRank(layout)); TView res = *this; res.view_desc_.layout = layout; @@ -236,7 +236,7 @@ class TView { /** Construct party TView by splitting original TView into num parts */ TView party_split_n(size_t num) const { - ICHECK_EQ(party_size_, 1); + TVM_FFI_ICHECK_EQ(party_size_, 1); TView res = *this; size_t rank = Tensor::getRank(view_desc_); @@ -255,7 +255,7 @@ class TView { /** Construct party TView by duplicating original TView num times */ TView party_duplicate_n(size_t num) const { - ICHECK_EQ(party_size_, 1); + TVM_FFI_ICHECK_EQ(party_size_, 1); TView res = *this; res.party_size_ = num; @@ -275,7 +275,7 @@ class TView { /** Return party element by index */ TView operator[](size_t i) const { - ICHECK_LT(i, party_size_); + TVM_FFI_ICHECK_LT(i, party_size_); TView res = *this; res.party_size_ = 1; @@ -328,7 +328,7 @@ class Primitive { /** Execute primitive with using specified src/dst */ void execute() { auto res = TVMBackendParallelLaunch(run_task, this, filters.size()); - ICHECK_EQ(res, 0) << "BNNS runtime. Primitive was not executed properly"; + TVM_FFI_ICHECK_EQ(res, 0) << "BNNS runtime. Primitive was not executed properly"; } private: diff --git a/src/runtime/contrib/cblas/cblas.cc b/src/runtime/contrib/cblas/cblas.cc index 85899b64f480..a91db72e5dab 100644 --- a/src/runtime/contrib/cblas/cblas.cc +++ b/src/runtime/contrib/cblas/cblas.cc @@ -127,36 +127,38 @@ struct CblasDgemmBatchIterativeOp { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_packed("tvm.contrib.cblas.matmul", - [](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + .def_packed( + "tvm.contrib.cblas.matmul", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) - CallGemm(args, ret, CblasSgemmOp()); - else - CallGemm(args, ret, CblasDgemmOp()); - }) - .def_packed("tvm.contrib.cblas.batch_matmul", - [](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallBatchGemm(args, ret, CblasSgemmBatchOp()); - } else { - CallBatchGemm(args, ret, CblasDgemmBatchOp()); - } - }) - .def_packed("tvm.contrib.cblas.batch_matmul_iterative", - [](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp()); - } else { - CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp()); - } - }); + if (TypeMatch(A->dtype, kDLFloat, 32)) + CallGemm(args, ret, CblasSgemmOp()); + else + CallGemm(args, ret, CblasDgemmOp()); + }) + .def_packed( + "tvm.contrib.cblas.batch_matmul", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, CblasSgemmBatchOp()); + } else { + CallBatchGemm(args, ret, CblasDgemmBatchOp()); + } + }) + .def_packed( + "tvm.contrib.cblas.batch_matmul_iterative", [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp()); + } else { + CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp()); + } + }); } } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cblas/dnnl_blas.cc b/src/runtime/contrib/cblas/dnnl_blas.cc index 9862a37301d3..d6a9baa21bc8 100644 --- a/src/runtime/contrib/cblas/dnnl_blas.cc +++ b/src/runtime/contrib/cblas/dnnl_blas.cc @@ -51,7 +51,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.contrib.dnnl.matmul", [](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); CallGemm(args, ret, DNNLSgemmOp()); }); } diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h index a44cf1b365ec..b276a3d167cd 100644 --- a/src/runtime/contrib/cblas/gemm_common.h +++ b/src/runtime/contrib/cblas/gemm_common.h @@ -80,23 +80,23 @@ inline void CallGemm(ffi::PackedArgs args, ffi::Any* ret, TGemmOp op) { bool transa = args[3].cast(); bool transb = args[4].cast(); int bit_depth = sizeof(typename TGemmOp::TDatatype) * 8; - ICHECK_EQ(A->ndim, 2); - ICHECK_EQ(B->ndim, 2); - ICHECK_EQ(C->ndim, 2); + TVM_FFI_ICHECK_EQ(A->ndim, 2); + TVM_FFI_ICHECK_EQ(B->ndim, 2); + TVM_FFI_ICHECK_EQ(C->ndim, 2); - ICHECK_EQ(ElementStride(A), 1); - ICHECK_EQ(ElementStride(B), 1); - ICHECK_EQ(ElementStride(C), 1); + TVM_FFI_ICHECK_EQ(ElementStride(A), 1); + TVM_FFI_ICHECK_EQ(ElementStride(B), 1); + TVM_FFI_ICHECK_EQ(ElementStride(C), 1); // C can never be transposed. - ICHECK(!IsInPlaceTransposed(C)); + TVM_FFI_ICHECK(!IsInPlaceTransposed(C)); // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed(A) ? !transa : transa; transb = IsInPlaceTransposed(B) ? !transb : transb; - ICHECK(TypeMatch(B->dtype, kDLFloat, bit_depth)); - ICHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); + TVM_FFI_ICHECK(TypeMatch(B->dtype, kDLFloat, bit_depth)); + TVM_FFI_ICHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? args[6].cast() : 0.0; op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), ColumnCount(A, transa), @@ -127,24 +127,24 @@ inline void CallU8S8S32Gemm(ffi::PackedArgs args, ffi::Any* ret, TGemmOp op) { int offset_c[1]; offset_c[0] = 0; - ICHECK_EQ(A->ndim, 2); - ICHECK_EQ(B->ndim, 2); - ICHECK_EQ(C->ndim, 2); + TVM_FFI_ICHECK_EQ(A->ndim, 2); + TVM_FFI_ICHECK_EQ(B->ndim, 2); + TVM_FFI_ICHECK_EQ(C->ndim, 2); - ICHECK_EQ(ElementStride(A), 1); - ICHECK_EQ(ElementStride(B), 1); - ICHECK_EQ(ElementStride(C), 1); + TVM_FFI_ICHECK_EQ(ElementStride(A), 1); + TVM_FFI_ICHECK_EQ(ElementStride(B), 1); + TVM_FFI_ICHECK_EQ(ElementStride(C), 1); // C can never be transposed. - ICHECK(!IsInPlaceTransposed(C)); + TVM_FFI_ICHECK(!IsInPlaceTransposed(C)); // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed(A) ? !transa : transa; transb = IsInPlaceTransposed(B) ? !transb : transb; - ICHECK(TypeMatch(A->dtype, kDLUInt, 8)); - ICHECK(TypeMatch(B->dtype, kDLInt, 8)); - ICHECK(TypeMatch(C->dtype, kDLInt, 32)); + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLUInt, 8)); + TVM_FFI_ICHECK(TypeMatch(B->dtype, kDLInt, 8)); + TVM_FFI_ICHECK(TypeMatch(C->dtype, kDLInt, 32)); double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? args[6].cast() : 0.0; op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), ColumnCount(A, transa), @@ -191,23 +191,23 @@ inline void CallBatchGemm(ffi::PackedArgs args, ffi::Any* ret, TBatchGemmOp op) int bit_depth = sizeof(DType) * 8; - ICHECK_EQ(A->ndim, 3); - ICHECK_EQ(B->ndim, 3); - ICHECK_EQ(C->ndim, 3); + TVM_FFI_ICHECK_EQ(A->ndim, 3); + TVM_FFI_ICHECK_EQ(B->ndim, 3); + TVM_FFI_ICHECK_EQ(C->ndim, 3); int batch_size = BatchCount3D(C); - ICHECK_EQ(ElementStride(A), 1); - ICHECK_EQ(ElementStride(B), 1); - ICHECK_EQ(ElementStride(C), 1); + TVM_FFI_ICHECK_EQ(ElementStride(A), 1); + TVM_FFI_ICHECK_EQ(ElementStride(B), 1); + TVM_FFI_ICHECK_EQ(ElementStride(C), 1); // C can never be transposed. - ICHECK(!IsInPlaceTransposed3D(C)); + TVM_FFI_ICHECK(!IsInPlaceTransposed3D(C)); // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed3D(A) ? !transa : transa; transb = IsInPlaceTransposed3D(B) ? !transb : transb; - ICHECK(TypeMatch(B->dtype, kDLFloat, bit_depth)); - ICHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); + TVM_FFI_ICHECK(TypeMatch(B->dtype, kDLFloat, bit_depth)); + TVM_FFI_ICHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? args[6].cast() : 0.0; @@ -226,8 +226,8 @@ inline void CallBatchGemm(ffi::PackedArgs args, ffi::Any* ret, TBatchGemmOp op) B_stride = 0; } } else { - ICHECK_EQ(batch_size_a, batch_size); - ICHECK_EQ(batch_size_b, batch_size); + TVM_FFI_ICHECK_EQ(batch_size_a, batch_size); + TVM_FFI_ICHECK_EQ(batch_size_b, batch_size); } DType* A_data = reinterpret_cast(static_cast(A->data) + diff --git a/src/runtime/contrib/cblas/mkl.cc b/src/runtime/contrib/cblas/mkl.cc index be8db227e554..59783134157c 100644 --- a/src/runtime/contrib/cblas/mkl.cc +++ b/src/runtime/contrib/cblas/mkl.cc @@ -42,7 +42,7 @@ inline CBLAS_TRANSPOSE MKLBooleanToTranspose(bool trans) { inline CBLAS_OFFSET MKLStringToOffset(const std::string offset_type) { if (offset_type != "CblasFixOffset" && offset_type != "CblasColOffset" && offset_type != "CblasRowOffset") { - LOG(FATAL) << "Unrecognized offset_type " << offset_type; + TVM_FFI_THROW(InternalError) << "Unrecognized offset_type " << offset_type; } if (offset_type == "CblasFixOffset") { return CblasFixOffset; @@ -159,7 +159,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm.contrib.mkl.matmul", [](ffi::PackedArgs args, ffi::Any* ret) { auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); if (TypeMatch(A->dtype, kDLFloat, 32)) CallGemm(args, ret, MKLSgemmOp()); @@ -177,31 +177,33 @@ TVM_FFI_STATIC_INIT_BLOCK() { auto A = args[0].cast(); auto B = args[1].cast(); auto C = args[2].cast(); - ICHECK(TypeMatch(A->dtype, kDLUInt, 8) && TypeMatch(B->dtype, kDLInt, 8) && - TypeMatch(C->dtype, kDLInt, 32)); + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLUInt, 8) && + TypeMatch(B->dtype, kDLInt, 8) && + TypeMatch(C->dtype, kDLInt, 32)); CallU8S8S32Gemm(args, ret, MKLGemmU8S8S32Op()); }) - .def_packed("tvm.contrib.mkl.batch_matmul", - [](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallBatchGemm(args, ret, MKLSgemmBatchOp()); - } else { - CallBatchGemm(args, ret, MKLDgemmBatchOp()); - } - }) - .def_packed("tvm.contrib.mkl.batch_matmul_iterative", - [](ffi::PackedArgs args, ffi::Any* ret) { - auto A = args[0].cast(); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallBatchGemm(args, ret, MKLSgemmBatchIterativeOp()); - } else { - CallBatchGemm(args, ret, MKLDgemmBatchIterativeOp()); - } - }); + .def_packed( + "tvm.contrib.mkl.batch_matmul", + [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, MKLSgemmBatchOp()); + } else { + CallBatchGemm(args, ret, MKLDgemmBatchOp()); + } + }) + .def_packed( + "tvm.contrib.mkl.batch_matmul_iterative", [](ffi::PackedArgs args, ffi::Any* ret) { + auto A = args[0].cast(); + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, MKLSgemmBatchIterativeOp()); + } else { + CallBatchGemm(args, ret, MKLDgemmBatchIterativeOp()); + } + }); } } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/clml/clml_memory_planner.cc b/src/runtime/contrib/clml/clml_memory_planner.cc index 9e61f557f48e..281db7436fc6 100644 --- a/src/runtime/contrib/clml/clml_memory_planner.cc +++ b/src/runtime/contrib/clml/clml_memory_planner.cc @@ -257,7 +257,7 @@ void ReleaseDDRMemory(cl_mem memptr) { if (0 == cws->ddr_global_pool[memptr].second) { LOG_MEM << "Release DDR mem from global pool"; result = clReleaseMemObject(memptr); - ICHECK(result == CL_SUCCESS) << "clReleaseMemObject:" << result; + TVM_FFI_ICHECK(result == CL_SUCCESS) << "clReleaseMemObject:" << result; cws->ddr_global_pool.erase(memptr); } } diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index a26055c2ca21..f6a872c07364 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -65,15 +65,15 @@ CLMLWorkspace::CLMLWorkspace() { // Print extensions size_t reqd_size = 0; result = clGetDeviceInfo(device_id, CL_DEVICE_EXTENSIONS, 0, nullptr, &reqd_size); - ICHECK(reqd_size > 0u && result == CL_SUCCESS) << "clGetDeviceInfo:" << result; + TVM_FFI_ICHECK(reqd_size > 0u && result == CL_SUCCESS) << "clGetDeviceInfo:" << result; std::vector extn_buf(reqd_size); result = clGetDeviceInfo(device_id, CL_DEVICE_EXTENSIONS, reqd_size, extn_buf.data(), nullptr); - ICHECK(result == CL_SUCCESS) << "clGetDeviceInfo:" << result; + TVM_FFI_ICHECK(result == CL_SUCCESS) << "clGetDeviceInfo:" << result; std::string extensions(extn_buf.data()); LOG_CLML << "OpenCL Extensions:" << extensions; if (extensions.find("cl_qcom_ml_ops") == std::string::npos) { - LOG(FATAL) << "CLML Runtime Init: Qualcomm extn not present.\n"; + TVM_FFI_THROW(InternalError) << "CLML Runtime Init: Qualcomm extn not present.\n"; return; } if (getenv("CLML_DISABLE_RECORDABLE_QUEUE")) { @@ -89,8 +89,8 @@ CLMLWorkspace::CLMLWorkspace() { if (is_on_chip_memory) { result = clGetDeviceInfo(device_id, CL_DEVICE_ONCHIP_GLOBAL_MEM_SIZE_QCOM, sizeof(onchip_mem_size), &onchip_mem_size, nullptr); - ICHECK(result == CL_SUCCESS) << "clGetDeviceInfo(CL_DEVICE_ONCHIP_GLOBAL_MEM_SIZE_QCOM):" - << result; + TVM_FFI_ICHECK(result == CL_SUCCESS) + << "clGetDeviceInfo(CL_DEVICE_ONCHIP_GLOBAL_MEM_SIZE_QCOM):" << result; LOG_CLML << "On chip memory size:" << onchip_mem_size; } @@ -100,12 +100,12 @@ CLMLWorkspace::CLMLWorkspace() { cl_int minorVersions[MAX_VERSIONS]; cl_uint numVersions = 0; result = clQueryMLInterfaceVersionsQCOM(nullptr, nullptr, 0, &numVersions); - ICHECK(result == CL_SUCCESS) << "clQueryMLInterfaceVersionsQCOM:" << result; - ICHECK(numVersions > 0u); - ICHECK(numVersions <= MAX_VERSIONS); + TVM_FFI_ICHECK(result == CL_SUCCESS) << "clQueryMLInterfaceVersionsQCOM:" << result; + TVM_FFI_ICHECK(numVersions > 0u); + TVM_FFI_ICHECK(numVersions <= MAX_VERSIONS); result = clQueryMLInterfaceVersionsQCOM(majorVersions, minorVersions, numVersions, nullptr); - ICHECK(result == CL_SUCCESS) << "clQueryMLInterfaceVersionsQCOM:" << result; + TVM_FFI_ICHECK(result == CL_SUCCESS) << "clQueryMLInterfaceVersionsQCOM:" << result; target_major = majorVersions[numVersions - 1]; target_minor = minorVersions[numVersions - 1]; @@ -123,10 +123,10 @@ CLMLWorkspace::CLMLWorkspace() { clGetMLInterfaceQCOM(&h_ClmlIntf, target_major, target_minor); - ICHECK(nullptr != h_ClmlIntf) << "Couldn't get API interface, target is not supported." - << "Compiled version: " << CL_QCOM_ML_OPS_H_MAJOR_VERSION << "." - << CL_QCOM_ML_OPS_H_MINOR_VERSION - << "Target Version:" << target_major << "." << target_minor; + TVM_FFI_ICHECK(nullptr != h_ClmlIntf) + << "Couldn't get API interface, target is not supported." + << "Compiled version: " << CL_QCOM_ML_OPS_H_MAJOR_VERSION << "." + << CL_QCOM_ML_OPS_H_MINOR_VERSION << "Target Version:" << target_major << "." << target_minor; char* tune_flag; if ((tune_flag = getenv("CLML_IS_TUNING_RUN"))) @@ -171,7 +171,7 @@ class CLMLRuntime : public JSONRuntimeBase { ReleaseDDRMemory(tensor_desc->memory); } else { result = clReleaseMemObject(tensor_desc->memory); - ICHECK(result == CL_SUCCESS) << "clReleaseMemObject:" << result; + TVM_FFI_ICHECK(result == CL_SUCCESS) << "clReleaseMemObject:" << result; } } for (size_t i = 0; i < this->layer_.function.size(); ++i) { @@ -207,7 +207,7 @@ class CLMLRuntime : public JSONRuntimeBase { * \param consts The constant params from compiled model. */ void Init(const ffi::Array& consts) override { - ICHECK_EQ(consts.size(), const_idx_.size()) + TVM_FFI_ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; SetupConstants(consts); @@ -227,10 +227,10 @@ class CLMLRuntime : public JSONRuntimeBase { if (cws->is_recordable_queue) { this->layer_.recordable_queue = clCreateCommandQueue(CLML_CTX, cws->device_id, CL_QUEUE_RECORDABLE_QCOM, &result); - ICHECK(result == CL_SUCCESS) << "clCreateCommandQueue - Recordable:" << result; + TVM_FFI_ICHECK(result == CL_SUCCESS) << "clCreateCommandQueue - Recordable:" << result; this->layer_.recording = clNewRecordingQCOM(this->layer_.recordable_queue, &result); - ICHECK(result == CL_SUCCESS) << "clNewRecordingQCOM:" << result; + TVM_FFI_ICHECK(result == CL_SUCCESS) << "clNewRecordingQCOM:" << result; } // A Tuning run, so create the cache from scratch @@ -270,8 +270,9 @@ class CLMLRuntime : public JSONRuntimeBase { std::string DebugDump(void) override { if (cws->is_recordable_queue) { - LOG(FATAL) << "Debugging over recordable queues is not supported yet. You may disable the " - "same by exporting CLML_DISABLE_RECORDABLE_QUEUE at runtime."; + TVM_FFI_THROW(InternalError) + << "Debugging over recordable queues is not supported yet. You may disable the " + "same by exporting CLML_DISABLE_RECORDABLE_QUEUE at runtime."; } namespace json = ::tvm::ffi::json; cl_command_queue queue = CLML_QUEUE; @@ -722,7 +723,7 @@ class CLMLRuntime : public JSONRuntimeBase { uint32_t eid = EntryID(nid, 0); node_data = data_entry_[eid]->data; usage = CL_TENSOR_USAGE_PARAMETER_QCOM; - ICHECK(CL_TENSOR_USAGE_INVALID_QCOM == this->layer_.storage_map[nid].usage) + TVM_FFI_ICHECK(CL_TENSOR_USAGE_INVALID_QCOM == this->layer_.storage_map[nid].usage) << "Parameter have usage reservation !!!"; } if (CL_TENSOR_USAGE_INVALID_QCOM != this->layer_.storage_map[nid].usage) { @@ -833,7 +834,7 @@ class CLMLRuntime : public JSONRuntimeBase { else if ("nn.batch_matmul" == op_name) CreateBatchMatmulLayer(&layer_, node, nid); else - LOG(FATAL) << "Unsupported op: " << op_name; + TVM_FFI_THROW(InternalError) << "Unsupported op: " << op_name; this->layer_.layer_names.push_back(op_name); // Keep map of function and Node to use in profiling this->layer_.op_node_map.insert({this->layer_.function.back(), std::make_pair(nid, node)}); @@ -893,8 +894,9 @@ class CLMLRuntime : public JSONRuntimeBase { tensor_desc->memory = AllocateDDRTensorMemory(mem_size); alloc_ddr += mem_size; } else { - LOG(FATAL) << "Mem allocation not found on DDR as well as On-Chip nid: " << it->first - << " Type:" << node.GetOpType(); + TVM_FFI_THROW(InternalError) + << "Mem allocation not found on DDR as well as On-Chip nid: " << it->first + << " Type:" << node.GetOpType(); } if (node.GetOpType() == "const") { @@ -958,7 +960,7 @@ class CLMLRuntime : public JSONRuntimeBase { strm.Write(clml_symbol); strm.Write(saved_cache); std::ofstream fs(cws->tuning_file, std::ios::app | std::ios::binary); - ICHECK(!fs.fail()) << "Cannot open " << cws->tuning_file; + TVM_FFI_ICHECK(!fs.fail()) << "Cannot open " << cws->tuning_file; fs.write(tune_str.data(), tune_str.size()); LOG_CLML << "CLML: Tuning cache dumped to:" << cws->tuning_file << " size" << tune_str.length() << " with tuning blob len " << saved_cache.size(); @@ -970,7 +972,7 @@ class CLMLRuntime : public JSONRuntimeBase { } result = clEndRecordingQCOM(this->layer_.recording); - ICHECK(result == CL_SUCCESS) << "clEndRecordingQCOM:" << result; + TVM_FFI_ICHECK(result == CL_SUCCESS) << "clEndRecordingQCOM:" << result; } } @@ -1014,7 +1016,7 @@ class CLMLRuntime : public JSONRuntimeBase { cl_uint groups = static_cast(node.GetAttr("groups")); if (CL_CONVOLUTION_MODE_CONVOLUTION_QCOM == mode) { - ICHECK(groups == 1) << "CLML convolution only supports group size of 1."; + TVM_FFI_ICHECK(groups == 1) << "CLML convolution only supports group size of 1."; } else { groups = 1; // Don't need to pass groups to depthwise } @@ -1024,7 +1026,7 @@ class CLMLRuntime : public JSONRuntimeBase { cl_activation_function_qcom clml_act_type = CL_ACTIVATION_RELU; if (node.HasAttr("activation_type")) { activation_type = std::string(node.GetAttr("activation_type")); - ICHECK(activation_type == "relu" || activation_type == "relu6") + TVM_FFI_ICHECK(activation_type == "relu" || activation_type == "relu6") << "Unknown activation type:" << activation_type; if (activation_type == "relu") { clml_act_type = CL_ACTIVATION_RELU; @@ -1041,7 +1043,7 @@ class CLMLRuntime : public JSONRuntimeBase { size_t num_inputs = inputs.size(); bool has_bias; bool has_bn; - ICHECK(num_inputs >= 2 && num_inputs <= 7) + TVM_FFI_ICHECK(num_inputs >= 2 && num_inputs <= 7) << "Batchnorm fused convolution requires max 7 arguments"; has_bias = (num_inputs == 3) || (num_inputs == 7); has_bn = (num_inputs == 6) || (num_inputs == 7); @@ -1061,7 +1063,7 @@ class CLMLRuntime : public JSONRuntimeBase { desc.num_dimensions = CL_TENSOR_UNUSED_QCOM; CLML_CALL_clCreateMLTensorQCOM(CLML_CTX, nullptr, &desc, CL_TENSOR_USAGE_UNUSED_QCOM, &layer_.unusedTensor); - ICHECK(layer_.unusedTensor) << "clCreateMLTensorQCOM: unusedTensor"; + TVM_FFI_ICHECK(layer_.unusedTensor) << "clCreateMLTensorQCOM: unusedTensor"; bias->tensor = layer_.unusedTensor; } // Output @@ -1154,11 +1156,11 @@ class CLMLRuntime : public JSONRuntimeBase { desc.num_dimensions = CL_TENSOR_UNUSED_QCOM; CLML_CALL_clCreateMLTensorQCOM(CLML_CTX, nullptr, &desc, CL_TENSOR_USAGE_UNUSED_QCOM, &layer_.unusedTensor); - ICHECK(layer_.unusedTensor) << "clCreateMLTensorQCOM: unusedTensor"; + TVM_FFI_ICHECK(layer_.unusedTensor) << "clCreateMLTensorQCOM: unusedTensor"; CLML_CALL(clCreateMLOpActivationForwardQCOM, CLML_CTX, nullptr, &act_desc, input->tensor, layer_.unusedTensor, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << "Activation Error"; + TVM_FFI_ICHECK(op) << "Activation Error"; layer->function.push_back(op); return; @@ -1210,7 +1212,7 @@ class CLMLRuntime : public JSONRuntimeBase { CLML_CALL(clCreateMLOpBatchNormForwardQCOM, CLML_CTX, opProperties.data(), &bn_desc, input->tensor, bn_mean->tensor, bn_var->tensor, bn_scale->tensor, bn_bias->tensor, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << "Batchnorm Error"; + TVM_FFI_ICHECK(op) << "Batchnorm Error"; layer->function.push_back(op); return; @@ -1265,11 +1267,11 @@ class CLMLRuntime : public JSONRuntimeBase { desc.num_dimensions = CL_TENSOR_UNUSED_QCOM; CLML_CALL_clCreateMLTensorQCOM(CLML_CTX, nullptr, &desc, CL_TENSOR_USAGE_UNUSED_QCOM, &unusedTensor); - ICHECK(unusedTensor) << "clCreateMLTensorQCOM: unusedTensor"; + TVM_FFI_ICHECK(unusedTensor) << "clCreateMLTensorQCOM: unusedTensor"; CLML_CALL(clCreateMLOpPoolingForwardQCOM, CLML_CTX, nullptr, &pool_desc, input->tensor, unusedTensor, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << "Pooling Error"; + TVM_FFI_ICHECK(op) << "Pooling Error"; layer->function.push_back(op); return; @@ -1311,11 +1313,11 @@ class CLMLRuntime : public JSONRuntimeBase { desc.num_dimensions = CL_TENSOR_UNUSED_QCOM; CLML_CALL_clCreateMLTensorQCOM(CLML_CTX, nullptr, &desc, CL_TENSOR_USAGE_UNUSED_QCOM, &layer_.unusedTensor); - ICHECK(layer_.unusedTensor) << "clCreateMLTensorQCOM: unusedTensor"; + TVM_FFI_ICHECK(layer_.unusedTensor) << "clCreateMLTensorQCOM: unusedTensor"; CLML_CALL(clCreateMLOpPoolingForwardQCOM, CLML_CTX, nullptr, &pool_desc, input->tensor, layer_.unusedTensor, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << "Pooling Error"; + TVM_FFI_ICHECK(op) << "Pooling Error"; layer->function.push_back(op); return; @@ -1369,7 +1371,7 @@ class CLMLRuntime : public JSONRuntimeBase { cl_arithmetic_mode}; CLML_CALL(clCreateMLOpSoftmaxQCOM, CLML_CTX, nullptr, &softmax_desc, input->tensor, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << "SoftMax Error"; + TVM_FFI_ICHECK(op) << "SoftMax Error"; layer->function.push_back(op); return; } @@ -1404,7 +1406,7 @@ class CLMLRuntime : public JSONRuntimeBase { else if (pad_mode == "reflect") clml_pad_mode = CL_PAD_MODE_REFLECT_QCOM; else - LOG(FATAL) << "Padding mode not supported by CLML:" << pad_mode; + TVM_FFI_THROW(InternalError) << "Padding mode not supported by CLML:" << pad_mode; cl_ml_op_pad_desc_qcom pad_desc{ clml_pad_mode, @@ -1414,7 +1416,7 @@ class CLMLRuntime : public JSONRuntimeBase { CLML_CALL(clCreateMLOpPadQCOM, CLML_CTX, nullptr, &pad_desc, input->tensor, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << "Pad Error"; + TVM_FFI_ICHECK(op) << "Pad Error"; layer->function.push_back(op); return; @@ -1437,7 +1439,7 @@ class CLMLRuntime : public JSONRuntimeBase { CLML_CALL(clCreateMLOpReshapeQCOM, CLML_CTX, nullptr, input->tensor, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << "Reshape Error"; + TVM_FFI_ICHECK(op) << "Reshape Error"; layer->function.push_back(op); return; @@ -1460,7 +1462,7 @@ class CLMLRuntime : public JSONRuntimeBase { CLML_CALL(clCreateMLOpReshapeQCOM, CLML_CTX, nullptr, input->tensor, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << "Reshape Error"; + TVM_FFI_ICHECK(op) << "Reshape Error"; layer->function.push_back(op); return; @@ -1493,7 +1495,7 @@ class CLMLRuntime : public JSONRuntimeBase { CLML_CALL(clCreateMLOpConcatQCOM, CLML_CTX, nullptr, &concatDesc, concatInputs, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << "Concat Error"; + TVM_FFI_ICHECK(op) << "Concat Error"; layer->function.push_back(op); @@ -1548,7 +1550,7 @@ class CLMLRuntime : public JSONRuntimeBase { CLML_CALL(clCreateMLOpFullyConnectedQCOM, CLML_CTX, nullptr, &fc_desc, input->tensor, weight->tensor, bias->tensor, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << "FC layer Error"; + TVM_FFI_ICHECK(op) << "FC layer Error"; layer->function.push_back(op); } else { cl_gemm_transform_qcom b_transform = CL_GEMM_TRANSFORM_NONE_QCOM; @@ -1565,7 +1567,7 @@ class CLMLRuntime : public JSONRuntimeBase { CLML_CALL(clCreateMLOpGemmQCOM, CLML_CTX, nullptr, &gemmDesc, input->tensor, weight->tensor, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << "Gemm layer Error"; + TVM_FFI_ICHECK(op) << "Gemm layer Error"; layer->function.push_back(op); if (has_bias) { cl_ml_op_binary_desc_qcom binaryDesc = {CL_TENSOR_OP_ADD_QCOM, @@ -1575,7 +1577,7 @@ class CLMLRuntime : public JSONRuntimeBase { cl_arithmetic_mode}; CLML_CALL(clCreateMLOpBinaryQCOM, CLML_CTX, nullptr, &binaryDesc, bias->tensor, layer_.unusedTensor, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << "Binary Op Error"; + TVM_FFI_ICHECK(op) << "Binary Op Error"; layer->function.push_back(op); } } @@ -1656,7 +1658,7 @@ class CLMLRuntime : public JSONRuntimeBase { CLML_CALL(clCreateMLOpGemmQCOM, CLML_CTX, nullptr, &gemmDesc, input->tensor, weight->tensor, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << "BatchMatmul Error"; + TVM_FFI_ICHECK(op) << "BatchMatmul Error"; layer->function.push_back(op); return; @@ -1715,7 +1717,7 @@ class CLMLRuntime : public JSONRuntimeBase { CLML_CALL_clCreateMLOpClipQCOM(CLML_CTX, nullptr, &clip_desc, input->tensor, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << "Clip Error"; + TVM_FFI_ICHECK(op) << "Clip Error"; layer->function.push_back(op); return; @@ -1753,13 +1755,13 @@ class CLMLRuntime : public JSONRuntimeBase { else if (op_name == "add" || PatternMatch(op_name, "relax.add")) binary_op = CL_TENSOR_OP_ADD_QCOM; else - LOG(FATAL) << "Undefined binary op:" << op_name; + TVM_FFI_THROW(InternalError) << "Undefined binary op:" << op_name; cl_ml_op_binary_desc_qcom add_desc = { binary_op, {{1.0}, CL_FLOAT}, {{1.0}, CL_FLOAT}, {{0.0}, CL_FLOAT}, cl_arithmetic_mode}; LOG(INFO) << "Op name - " << op_name; CLML_CALL(clCreateMLOpBinaryQCOM, CLML_CTX, nullptr, &add_desc, input_a->tensor, input_b->tensor, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << op_name << " Node Error"; + TVM_FFI_ICHECK(op) << op_name << " Node Error"; layer->function.push_back(op); return; @@ -1785,7 +1787,7 @@ class CLMLRuntime : public JSONRuntimeBase { cl_ml_op_depthtospace_desc_qcom dtos_desc = {block_size, cl_arithmetic_mode}; CLML_CALL(clCreateMLOpDepthToSpaceQCOM, CLML_CTX, nullptr, &dtos_desc, input->tensor, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << "DepthToSpace Layer Error"; + TVM_FFI_ICHECK(op) << "DepthToSpace Layer Error"; layer->function.push_back(op); return; @@ -1811,7 +1813,7 @@ class CLMLRuntime : public JSONRuntimeBase { cl_ml_op_resize_bilinear_desc_qcom resize_desc = {align_corners, false, cl_arithmetic_mode}; CLML_CALL(clCreateMLOpResizeBilinearQCOM, CLML_CTX, nullptr, &resize_desc, input->tensor, output->tensor, &op, layer_.tuning_cache); - ICHECK(op) << "Resize Layer Error"; + TVM_FFI_ICHECK(op) << "Resize Layer Error"; layer->function.push_back(op); return; @@ -1830,8 +1832,8 @@ class CLMLRuntime : public JSONRuntimeBase { #else void Run() override { - LOG(FATAL) << "Cannot call run on CLML module without runtime enabled. " - << "Please build with USE_CLML_GRAPH_EXECUTOR."; + TVM_FFI_THROW(InternalError) << "Cannot call run on CLML module without runtime enabled. " + << "Please build with USE_CLML_GRAPH_EXECUTOR."; } void BuildEngine() { diff --git a/src/runtime/contrib/clml/clml_runtime.h b/src/runtime/contrib/clml/clml_runtime.h index 716ea4665ea4..a8dfe58bdf6e 100644 --- a/src/runtime/contrib/clml/clml_runtime.h +++ b/src/runtime/contrib/clml/clml_runtime.h @@ -56,7 +56,7 @@ #define CAT(a, b) CAT_I(a, b) #define CLML_CHECK_ERROR(e, API) \ - { ICHECK(e == CL_SUCCESS) << "CLML Error:" #API " code=" << e; } + { TVM_FFI_ICHECK(e == CL_SUCCESS) << "CLML Error:" #API " code=" << e; } #if CL_QCOM_ML_OPS_H_MAJOR_VERSION > 3 #define V4_API(API, ...) \ @@ -64,7 +64,8 @@ ->API(__VA_ARGS__); \ CLML_CHECK_ERROR(e, API); #else -#define V4_API(API, ...) LOG(FATAL) << "CLML Error:" #API " - Incompatible V4 API call\n"; +#define V4_API(API, ...) \ + TVM_FFI_THROW(InternalError) << "CLML Error:" #API " - Incompatible V4 API call\n"; #endif #if CL_QCOM_ML_OPS_H_MAJOR_VERSION > 2 @@ -73,7 +74,8 @@ ->API(__VA_ARGS__); \ CLML_CHECK_ERROR(e, API); #else -#define V3_API(API, ...) LOG(FATAL) << "CLML Error:" #API " - Incompatible V3 API call\n"; +#define V3_API(API, ...) \ + TVM_FFI_THROW(InternalError) << "CLML Error:" #API " - Incompatible V3 API call\n"; #endif #if CL_QCOM_ML_OPS_H_MAJOR_VERSION > 1 @@ -82,7 +84,8 @@ ->API(__VA_ARGS__); \ CLML_CHECK_ERROR(e, API); #else -#define V2_API(API, ...) LOG(FATAL) << "CLML Error:" #API " - Incompatible V2 API call\n"; +#define V2_API(API, ...) \ + TVM_FFI_THROW(InternalError) << "CLML Error:" #API " - Incompatible V2 API call\n"; #endif #define V1_API(API, ...) \ @@ -90,25 +93,25 @@ ->API(__VA_ARGS__); \ CLML_CHECK_ERROR(e, API); -#define CLML_CALL(API, ...) \ - { \ - cl_int e; \ - switch (CLMLWorkspace::Global()->target_major) { \ - case 1: \ - V1_API(API, __VA_ARGS__); \ - break; \ - case 2: \ - V2_API(API, __VA_ARGS__); \ - break; \ - case 3: \ - V3_API(API, __VA_ARGS__); \ - break; \ - case 4: \ - V4_API(API, __VA_ARGS__); \ - break; \ - default: \ - LOG(FATAL) << "CLML Error:" #API " - Unsupported target version \n"; \ - } \ +#define CLML_CALL(API, ...) \ + { \ + cl_int e; \ + switch (CLMLWorkspace::Global()->target_major) { \ + case 1: \ + V1_API(API, __VA_ARGS__); \ + break; \ + case 2: \ + V2_API(API, __VA_ARGS__); \ + break; \ + case 3: \ + V3_API(API, __VA_ARGS__); \ + break; \ + case 4: \ + V4_API(API, __VA_ARGS__); \ + break; \ + default: \ + TVM_FFI_THROW(InternalError) << "CLML Error:" #API " - Unsupported target version \n"; \ + } \ } #define CLML_CALL_VERSIONED(APICALL, VERSION, ...) CAT(CAT(V, VERSION), _API)(APICALL, __VA_ARGS__) @@ -119,14 +122,14 @@ break; // clCreateMLOpClipQCOM -#define CLML_CALL_clCreateMLOpClipQCOM(...) \ - cl_int e; \ - switch (CLMLWorkspace::Global()->target_major) { \ - CALL_CASE(2, clCreateMLOpClipQCOM, __VA_ARGS__) \ - CALL_CASE(3, clCreateMLOpClipQCOM, __VA_ARGS__) \ - CALL_CASE(4, clCreateMLOpClipQCOM, __VA_ARGS__) \ - default: \ - LOG(FATAL) << "CLML Error: - Unsupported target version \n"; \ +#define CLML_CALL_clCreateMLOpClipQCOM(...) \ + cl_int e; \ + switch (CLMLWorkspace::Global()->target_major) { \ + CALL_CASE(2, clCreateMLOpClipQCOM, __VA_ARGS__) \ + CALL_CASE(3, clCreateMLOpClipQCOM, __VA_ARGS__) \ + CALL_CASE(4, clCreateMLOpClipQCOM, __VA_ARGS__) \ + default: \ + TVM_FFI_THROW(InternalError) << "CLML Error: - Unsupported target version \n"; \ } // clCreateMLTensorQCOM and clCreateMLTensorWithUsageQCOM @@ -137,15 +140,15 @@ TENSOR) \ CALL_CASE(VERSION, clCreateMLTensorWithUsageQCOM, CONTEXT, TENSORPROPS, TENSORDESC, USAGE, TENSOR) -#define CLML_CALL_clCreateMLTensorQCOM(...) \ - cl_int e; \ - switch (CLMLWorkspace::Global()->target_major) { \ - CALL_clCreateMLTensorQCOM(1, __VA_ARGS__); \ - CALL_clCreateMLTensorQCOM(2, __VA_ARGS__); \ - CALL_clCreateMLTensorQCOM(3, __VA_ARGS__); \ - CALL_clCreateMLTensorWithUsageQCOM(4, __VA_ARGS__); \ - default: \ - LOG(FATAL) << "CLML Error: - Unsupported target version \n"; \ +#define CLML_CALL_clCreateMLTensorQCOM(...) \ + cl_int e; \ + switch (CLMLWorkspace::Global()->target_major) { \ + CALL_clCreateMLTensorQCOM(1, __VA_ARGS__); \ + CALL_clCreateMLTensorQCOM(2, __VA_ARGS__); \ + CALL_clCreateMLTensorQCOM(3, __VA_ARGS__); \ + CALL_clCreateMLTensorWithUsageQCOM(4, __VA_ARGS__); \ + default: \ + TVM_FFI_THROW(InternalError) << "CLML Error: - Unsupported target version \n"; \ } /* Version compatibility for CLML Tensor creation */ diff --git a/src/runtime/contrib/clml/clml_utils.cc b/src/runtime/contrib/clml/clml_utils.cc index 356adc5c570d..8a3f978f237c 100644 --- a/src/runtime/contrib/clml/clml_utils.cc +++ b/src/runtime/contrib/clml/clml_utils.cc @@ -43,7 +43,7 @@ void CopyDataToCLMLTensor(std::shared_ptr tensor, cl_event evt = nullptr; CLML_CALL(clEnqueueWriteMLTensorDataQCOM, CLML_QUEUE, data, layout, tensor->tensor, tensor->memory, 0, nullptr, &evt); - ICHECK(evt != nullptr) << "clEnqueueWriteMLTensorDataQCOM"; + TVM_FFI_ICHECK(evt != nullptr) << "clEnqueueWriteMLTensorDataQCOM"; } /*! @@ -61,7 +61,7 @@ void CopyDataFromCLMLTensor(std::shared_ptr tenso CLML_CALL(clEnqueueReadMLTensorDataQCOM, CLML_QUEUE, tensor->tensor, tensor->memory, data, layout, 0, nullptr, &readEvent); result = clWaitForEvents(1, &readEvent); - ICHECK(result == CL_SUCCESS) << "clWaitForEvents:" << result; + TVM_FFI_ICHECK(result == CL_SUCCESS) << "clWaitForEvents:" << result; } /*! @@ -81,7 +81,7 @@ cl_ml_tensor_qcom DeviceMakeCLMLTensor(cl_context context, tensor_dims_t dims, cl_ml_tensor_desc_qcom desc = { dtype, layout, dims.n, dims.c, dims.h, dims.w, 0, CL_TENSOR_DIMENSIONS_4D_QCOM, {0}}; CLML_CALL_clCreateMLTensorQCOM(CLML_CTX, nullptr, &desc, usage, &tensor); - ICHECK(tensor) << "clCreateMLTensorQCOM"; + TVM_FFI_ICHECK(tensor) << "clCreateMLTensorQCOM"; return tensor; } @@ -97,7 +97,7 @@ cl_mem AllocateDDRTensorMemory(size_t size) { cl_mem buffer = nullptr; buffer = clCreateBuffer(CLML_CTX, CL_MEM_READ_WRITE, size, nullptr, &result); - ICHECK(result == CL_SUCCESS) << "clCreateBuffer:" << result; + TVM_FFI_ICHECK(result == CL_SUCCESS) << "clCreateBuffer:" << result; return buffer; } @@ -119,7 +119,7 @@ cl_mem AllocateOnChipTensorMemory(size_t size, cl_uint on_chip_mem_offset) { LOG_MEM << "On-Chip Alloc:" << size << " Offset:" << on_chip_mem_offset; buffer = clCreateBufferWithProperties(CLML_CTX, on_chip_buff_prop, CL_MEM_READ_WRITE, size, nullptr, &result); - ICHECK(result == CL_SUCCESS) << "clCreateBufferWithProperties:" << result; + TVM_FFI_ICHECK(result == CL_SUCCESS) << "clCreateBufferWithProperties:" << result; return buffer; } @@ -152,7 +152,7 @@ cl_channel_type MakeCLDataType(const DLDataType& data_type) { } else if (data_type.code == DLDataTypeCode::kDLFloat && data_type.bits == 16) { return CL_HALF_FLOAT; } else { - LOG(FATAL) << "Datatype " << data_type << " unsupported by CLML runtime"; + TVM_FFI_THROW(InternalError) << "Datatype " << data_type << " unsupported by CLML runtime"; } } @@ -172,7 +172,7 @@ cl_arithmetic_mode_qcom MakeCLArithMode(const cl_channel_type& data_type, } else if (data_type == CL_HALF_FLOAT && acc_type == CL_HALF_FLOAT) { return CL_ARITHMETIC_MODE_FP16_QCOM; } else { - LOG(FATAL) << "Datatype " << data_type << " unsupported by CLML runtime"; + TVM_FFI_THROW(InternalError) << "Datatype " << data_type << " unsupported by CLML runtime"; } } diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 715172ecd8f9..6e667ee37878 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -144,7 +144,7 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, const DLTensor* C, bool transa, bool transb, void* workspace_ptr, size_t workspace_size, cublasLtEpilogue_t epilogue, std::optional dq_scale) { - ICHECK(TypeEqual(A->dtype, B->dtype)); + TVM_FFI_ICHECK(TypeEqual(A->dtype, B->dtype)); // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed(A) ? !transa : transa; transb = IsInPlaceTransposed(B) ? !transb : transb; @@ -170,7 +170,7 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, } else if (TypeMatch(A->dtype, kDLInt, 8)) { ab_type = CUDA_R_8I; } else if (TypeMatch(A->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)) { - ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)); + TVM_FFI_ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)); ab_type = CUDA_R_8F_E4M3; } @@ -274,7 +274,7 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, // cuBLASLt does not seem to support batched GEMM with one of matrices having // one batch (with batch_stride 0). - ICHECK_EQ(batch_count_A, batch_count_B); + TVM_FFI_ICHECK_EQ(batch_count_A, batch_count_B); set_batch(A_desc, batch_count_A, batch_stride_A); set_batch(B_desc, batch_count_B, batch_stride_B); @@ -327,19 +327,19 @@ inline void CallLtIgemm(ffi::PackedArgs args, ffi::Any* ret, cublasLtHandle_t hd int lda = M * K / (roundoff(K, 32) / 32); int ldb = K * N / (roundoff(K, 32) / 32); int ldc = M * N_out / (roundoff(N_out, 32) / 32); - ICHECK_EQ(A->ndim, 2); - ICHECK_EQ(B->ndim, 2); - ICHECK_EQ(C->ndim, 2); + TVM_FFI_ICHECK_EQ(A->ndim, 2); + TVM_FFI_ICHECK_EQ(B->ndim, 2); + TVM_FFI_ICHECK_EQ(C->ndim, 2); - ICHECK_EQ(ElementStride(A), 1); - ICHECK_EQ(ElementStride(B), 1); - ICHECK_EQ(ElementStride(C), 1); + TVM_FFI_ICHECK_EQ(ElementStride(A), 1); + TVM_FFI_ICHECK_EQ(ElementStride(B), 1); + TVM_FFI_ICHECK_EQ(ElementStride(C), 1); - ICHECK(TypeEqual(A->dtype, B->dtype)); - ICHECK(TypeMatch(A->dtype, kDLInt, 8)); - ICHECK(TypeMatch(C->dtype, kDLInt, 32)); + TVM_FFI_ICHECK(TypeEqual(A->dtype, B->dtype)); + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLInt, 8)); + TVM_FFI_ICHECK(TypeMatch(C->dtype, kDLInt, 32)); - ICHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; + TVM_FFI_ICHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; int32_t alpha = args.size() > 5 ? args[5].cast() : 1; int32_t beta = args.size() > 6 ? args[6].cast() : 0; cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr; @@ -386,27 +386,27 @@ inline void CallGemmEx(ffi::PackedArgs args, ffi::Any* ret, cublasHandle_t hdl) auto C = args[2].cast(); bool transa = args[3].cast(); bool transb = args[4].cast(); - ICHECK_EQ(A->ndim, 2); - ICHECK_EQ(B->ndim, 2); - ICHECK_EQ(C->ndim, 2); + TVM_FFI_ICHECK_EQ(A->ndim, 2); + TVM_FFI_ICHECK_EQ(B->ndim, 2); + TVM_FFI_ICHECK_EQ(C->ndim, 2); - ICHECK_EQ(ElementStride(A), 1); - ICHECK_EQ(ElementStride(B), 1); - ICHECK_EQ(ElementStride(C), 1); + TVM_FFI_ICHECK_EQ(ElementStride(A), 1); + TVM_FFI_ICHECK_EQ(ElementStride(B), 1); + TVM_FFI_ICHECK_EQ(ElementStride(C), 1); - ICHECK(TypeEqual(A->dtype, B->dtype)); + TVM_FFI_ICHECK(TypeEqual(A->dtype, B->dtype)); // C can never be transposed. - ICHECK(!IsInPlaceTransposed(C)); + TVM_FFI_ICHECK(!IsInPlaceTransposed(C)); // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed(A) ? !transa : transa; transb = IsInPlaceTransposed(B) ? !transb : transb; - ICHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; - ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0) + TVM_FFI_ICHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; + TVM_FFI_ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; - ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0) + TVM_FFI_ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? args[6].cast() : 0.0; @@ -444,28 +444,28 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, cublasHandle_t auto C = args[2].cast(); bool transa = args[3].cast(); bool transb = args[4].cast(); - ICHECK_EQ(A->ndim, 3); - ICHECK_EQ(B->ndim, 3); - ICHECK_EQ(C->ndim, 3); + TVM_FFI_ICHECK_EQ(A->ndim, 3); + TVM_FFI_ICHECK_EQ(B->ndim, 3); + TVM_FFI_ICHECK_EQ(C->ndim, 3); int batch_size = BatchCount3D(C); - ICHECK_EQ(ElementStride3D(A), 1); - ICHECK_EQ(ElementStride3D(B), 1); - ICHECK_EQ(ElementStride3D(C), 1); + TVM_FFI_ICHECK_EQ(ElementStride3D(A), 1); + TVM_FFI_ICHECK_EQ(ElementStride3D(B), 1); + TVM_FFI_ICHECK_EQ(ElementStride3D(C), 1); - ICHECK(TypeEqual(A->dtype, B->dtype)); + TVM_FFI_ICHECK(TypeEqual(A->dtype, B->dtype)); // C can never be transposed. - ICHECK(!IsInPlaceTransposed3D(C)); + TVM_FFI_ICHECK(!IsInPlaceTransposed3D(C)); // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed3D(A) ? !transa : transa; transb = IsInPlaceTransposed3D(B) ? !transb : transb; - ICHECK(CheckMixPrecisionType(A->dtype, C->dtype, true)) << "Unsupported data type"; - ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride3D(A) % 4 == 0) + TVM_FFI_ICHECK(CheckMixPrecisionType(A->dtype, C->dtype, true)) << "Unsupported data type"; + TVM_FFI_ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride3D(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; - ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride3D(B) % 4 == 0) + TVM_FFI_ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride3D(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? args[6].cast() : 0.0; @@ -484,8 +484,8 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, cublasHandle_t B_stride = 0; } } else { - ICHECK_EQ(batch_size_a, batch_size); - ICHECK_EQ(batch_size_b, batch_size); + TVM_FFI_ICHECK_EQ(batch_size_a, batch_size); + TVM_FFI_ICHECK_EQ(batch_size_b, batch_size); } cudaDataType_t cuda_in_type = GetCudaDataType(A->dtype); @@ -528,8 +528,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { CUBLASTryEnableTensorCore(entry_ptr->handle); if (TypeEqual(A->dtype, C->dtype)) { - ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || - TypeMatch(A->dtype, kDLFloat, 64)); + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || + TypeMatch(A->dtype, kDLFloat, 64)); if (TypeMatch(A->dtype, kDLFloat, 16)) CallGemm(args, ret, CublasHgemmOp(entry_ptr->handle)); @@ -554,7 +554,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { CUBLASTryEnableTensorCore(entry_ptr->handle); - ICHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n"; + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n"; cublasLtHandle_t ltHandle; CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); cudaStream_t stream = @@ -576,8 +576,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { CUBLASTryEnableTensorCore(entry_ptr->handle); if (TypeEqual(A->dtype, C->dtype)) { - ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || - TypeMatch(A->dtype, kDLFloat, 64)); + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || + TypeMatch(A->dtype, kDLFloat, 64)); if (TypeMatch(A->dtype, kDLFloat, 16)) CallBatchGemm(args, ret, CublasHgemmBatchOp(entry_ptr->handle)); diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 08c871baea4f..1b5d4ebc8570 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -58,7 +58,7 @@ class CublasJSONRuntime : public JSONRuntimeBase { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (this->symbol_name_ == name) { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK(this->initialized_) << "The module has not been initialized"; + TVM_FFI_ICHECK(this->initialized_) << "The module has not been initialized"; this->Run(args); }); } else { @@ -94,9 +94,9 @@ class CublasJSONRuntime : public JSONRuntimeBase { cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) { - ICHECK_LT(idx, node.GetInputs().size()); + TVM_FFI_ICHECK_LT(idx, node.GetInputs().size()); auto eid = EntryID(node.GetInputs()[idx]); - ICHECK(eid < dl_tensors.size()); + TVM_FFI_ICHECK(eid < dl_tensors.size()); return dl_tensors[eid]; }; @@ -150,7 +150,7 @@ class CublasJSONRuntime : public JSONRuntimeBase { } } - void Run() override { LOG(FATAL) << "Unreachable"; } + void Run() override { TVM_FFI_THROW(InternalError) << "Unreachable"; } }; ffi::Module CublasJSONRuntimeCreate(ffi::String symbol_name, ffi::String graph_json, diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index 9c99a83250db..49f93f66dcc2 100644 --- a/src/runtime/contrib/cublas/cublas_utils.h +++ b/src/runtime/contrib/cublas/cublas_utils.h @@ -64,10 +64,10 @@ inline const char* GetCublasErrorString(int error) { } #ifndef CHECK_CUBLAS_ERROR -#define CHECK_CUBLAS_ERROR(fn) \ - do { \ - int error = static_cast(fn); \ - ICHECK_EQ(error, CUBLAS_STATUS_SUCCESS) << "CUBLAS: " << GetCublasErrorString(error); \ +#define CHECK_CUBLAS_ERROR(fn) \ + do { \ + int error = static_cast(fn); \ + TVM_FFI_ICHECK_EQ(error, CUBLAS_STATUS_SUCCESS) << "CUBLAS: " << GetCublasErrorString(error); \ } while (0) // ; intentionally left off. #endif // CHECK_CUBLAS_ERROR @@ -122,7 +122,7 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) { return CUDA_R_16BF; } } - LOG(FATAL) << "Unsupported CUDA type"; + TVM_FFI_THROW(InternalError) << "Unsupported CUDA type"; } /*! \brief Execute matrix multiply followed by the specified epilogue, using cuBLASLt. */ diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc index fbde314bc6ae..5dea47176c1a 100644 --- a/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc +++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc @@ -39,7 +39,7 @@ void CuDNNSDPARunnerNode::Init(int64_t batch, int64_t seq_len, int64_t num_heads const std::string& layout) { graph_ = std::make_unique(); - CHECK(data_type.code == DLDataTypeCode::kDLFloat && data_type.bits == 16) + TVM_FFI_ICHECK(data_type.code == DLDataTypeCode::kDLFloat && data_type.bits == 16) << "Only float16 is supported"; graph_->set_io_data_type(cudnn_frontend::DataType_t::HALF) @@ -69,7 +69,7 @@ void CuDNNSDPARunnerNode::Init(int64_t batch, int64_t seq_len, int64_t num_heads offset_k_ = num_heads * head_size; offset_v_ = offset_k_ + num_kv_heads * head_size; } else if (layout == "SBN3H") { - CHECK_EQ(num_kv_heads, num_heads); + TVM_FFI_ICHECK_EQ(num_kv_heads, num_heads); int64_t stride_H = 1; int64_t stride_N = head_size + head_size + head_size_v; int64_t stride_B = num_heads * stride_N; @@ -79,7 +79,7 @@ void CuDNNSDPARunnerNode::Init(int64_t batch, int64_t seq_len, int64_t num_heads offset_k_ = head_size; offset_v_ = offset_k_ * 2; } else { - LOG(FATAL) << "Unsupported layout: " << layout; + TVM_FFI_THROW(InternalError) << "Unsupported layout: " << layout; } q_desc = q_desc.set_dim({batch, num_heads, seq_len, head_size}).set_stride(q_stride); @@ -96,7 +96,7 @@ void CuDNNSDPARunnerNode::Init(int64_t batch, int64_t seq_len, int64_t num_heads auto k = graph_->tensor(k_desc); auto v = graph_->tensor(v_desc); auto [o, stats] = graph_->sdpa(q, k, v, sdpa_options); - CHECK(stats == nullptr); + TVM_FFI_ICHECK(stats == nullptr); o->set_output(true).set_dim({batch, num_heads, seq_len, head_size_v}).set_stride(o_stride); int device_id; CUDA_CALL(cudaGetDevice(&device_id)); @@ -112,7 +112,7 @@ void CuDNNSDPARunnerNode::Run(const DLTensor* qkv, DLTensor* workspace, DLTensor auto* out_ptr = reinterpret_cast(out->data) + out->byte_offset; size_t workspace_size = graph_->get_workspace_size(); - CHECK_LE(workspace_size, workspace->shape[0]) << "Workspace size too small"; + TVM_FFI_ICHECK_LE(workspace_size, workspace->shape[0]) << "Workspace size too small"; std::unordered_map inputs = { {kTensorIDQ, q_ptr}, {kTensorIDK, k_ptr}, {kTensorIDV, v_ptr}, {kTensorIDOut, out_ptr}}; diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.h b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h index 248d44d9d65f..9761fa20bebe 100644 --- a/src/runtime/contrib/cudnn/cudnn_frontend/attention.h +++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.h @@ -31,10 +31,10 @@ #include #include -#define CUDNN_FRONTEND_CALL(func) \ - do { \ - auto status = (func); \ - CHECK(status.is_good()) << status.get_message(); \ +#define CUDNN_FRONTEND_CALL(func) \ + do { \ + auto status = (func); \ + TVM_FFI_ICHECK(status.is_good()) << status.get_message(); \ } while (0) namespace tvm { diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc index d8cb5e654ad5..df38f960d294 100644 --- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc +++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc @@ -64,7 +64,7 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { } else if (op_name.find("attention") != std::string::npos) { op_execs_[i] = GetAttentionExec(node); } else { - LOG(FATAL) << "Unsupported op: " << op_name; + TVM_FFI_THROW(InternalError) << "Unsupported op: " << op_name; } } } @@ -82,9 +82,9 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { private: const DLTensor* GetInput(const JSONGraphNode& node, const int idx) { - ICHECK_LT(idx, node.GetInputs().size()); + TVM_FFI_ICHECK_LT(idx, node.GetInputs().size()); auto eid = EntryID(node.GetInputs()[idx]); - ICHECK(eid < data_entry_.size()); + TVM_FFI_ICHECK(eid < data_entry_.size()); return data_entry_[eid]; } @@ -136,7 +136,7 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { } else if (layout == "NHWC") { format = CUDNN_TENSOR_NHWC; } else { - LOG(FATAL) << "Unsupported layout: " << layout; + TVM_FFI_THROW(InternalError) << "Unsupported layout: " << layout; } int act = CUDNN_ACTIVATION_IDENTITY; @@ -205,15 +205,15 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { int64_t batch, seq_len; if (layout == "BS3NH") { - ICHECK_EQ(qkv_shapes.size(), 3); + TVM_FFI_ICHECK_EQ(qkv_shapes.size(), 3); batch = qkv_shapes[0]; seq_len = qkv_shapes[1]; } else if (layout == "SBN3H") { - ICHECK_EQ(qkv_shapes.size(), 4); + TVM_FFI_ICHECK_EQ(qkv_shapes.size(), 4); batch = qkv_shapes[1]; seq_len = qkv_shapes[0]; } else { - LOG(FATAL) << "Unsupported layout: " << layout; + TVM_FFI_THROW(InternalError) << "Unsupported layout: " << layout; } double scale = 1 / std::sqrt(head_size); if (node.HasAttr("scale")) { @@ -230,7 +230,7 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { runner->Run(qkv, workspace, out); }; #else - LOG(FATAL) << "Please build with CUDNN frontend to use attention op"; + TVM_FFI_THROW(InternalError) << "Please build with CUDNN frontend to use attention op"; return nullptr; #endif } diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc index 4a9581f25959..5793febf01ef 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.cc +++ b/src/runtime/contrib/cudnn/cudnn_utils.cc @@ -45,10 +45,10 @@ cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType& dtype) { else if (dtype.bits == 8 && dtype.lanes == 4) return CUDNN_DATA_INT8x4; else - LOG(FATAL) << "Unsupported type"; + TVM_FFI_THROW(InternalError) << "Unsupported type"; break; case kDLUInt: - LOG(FATAL) << "Unsupported type"; + TVM_FFI_THROW(InternalError) << "Unsupported type"; break; case kDLFloat: if (dtype.bits == 32 && dtype.lanes == 1) @@ -58,7 +58,7 @@ cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType& dtype) { else if (dtype.bits == 16 && dtype.lanes == 1) return CUDNN_DATA_HALF; else - LOG(FATAL) << "Unsupported type"; + TVM_FFI_THROW(InternalError) << "Unsupported type"; break; } return CUDNN_DATA_FLOAT; @@ -124,7 +124,7 @@ CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal(Device curr_device, bool check_e static thread_local CuDNNThreadEntry inst; auto* res = &inst; if (check_exists) { - ICHECK(res->exists()) << "CUDNN_STATUS_NOT_INITIALIZED"; + TVM_FFI_ICHECK(res->exists()) << "CUDNN_STATUS_NOT_INITIALIZED"; } cudaStream_t stream = @@ -225,7 +225,8 @@ void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int format, int dims, int g static_cast(y_dim[ni]), static_cast(y_dim[ci]), static_cast(y_dim[hi]), static_cast(y_dim[wi]))); } else { - ICHECK_EQ(format, 0) << "Use of layout CUDNN_TENSOR_NHWC is supported only for 4-D tensors."; + TVM_FFI_ICHECK_EQ(format, 0) + << "Use of layout CUDNN_TENSOR_NHWC is supported only for 4-D tensors."; CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride, dilation, entry_ptr->conv_entry.mode, diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h index 499cc5d6c9e5..58eac57c679d 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.h +++ b/src/runtime/contrib/cudnn/cudnn_utils.h @@ -35,10 +35,10 @@ namespace tvm { namespace contrib { -#define CUDNN_CALL(func) \ - { \ - cudnnStatus_t e = (func); \ - ICHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e); \ +#define CUDNN_CALL(func) \ + { \ + cudnnStatus_t e = (func); \ + TVM_FFI_ICHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e); \ } /*! breif Convert DLTensor type to CuDNN type */ diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc index 10df70670c70..d494fb334946 100644 --- a/src/runtime/contrib/cudnn/softmax.cc +++ b/src/runtime/contrib/cudnn/softmax.cc @@ -39,7 +39,7 @@ void softmax_impl(cudnnSoftmaxAlgorithm_t alg, ffi::PackedArgs args, ffi::Any* r int ndim = x->ndim; int64_t* shape = x->shape; if (axis < 0) axis += ndim; - ICHECK(axis >= 0 && axis < ndim); + TVM_FFI_ICHECK(axis >= 0 && axis < ndim); int device_id; CUDA_CALL(cudaGetDevice(&device_id)); CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id}); diff --git a/src/runtime/contrib/curand/curand.cc b/src/runtime/contrib/curand/curand.cc index 53505770f83a..4dd0ca145c37 100644 --- a/src/runtime/contrib/curand/curand.cc +++ b/src/runtime/contrib/curand/curand.cc @@ -28,10 +28,10 @@ namespace tvm { namespace runtime { namespace curand { -#define TVM_CURAND_CALL(func) \ - { \ - curandStatus_t e = (func); \ - ICHECK(e == CURAND_STATUS_SUCCESS) << "cuRAND error: " << e; \ +#define TVM_CURAND_CALL(func) \ + { \ + curandStatus_t e = (func); \ + TVM_FFI_ICHECK(e == CURAND_STATUS_SUCCESS) << "cuRAND error: " << e; \ } class CURandGenerator { @@ -77,8 +77,8 @@ struct DeferredFunc { void RandomFill(DLTensor* tensor) { static DeviceAPI* cuda_api = GetCUDADeviceAPI(); - CHECK(tensor->device.device_type == DLDeviceType::kDLCUDA) - << "ValueError: cuRAND only works on CUDA devices"; + TVM_FFI_CHECK(tensor->device.device_type == DLDeviceType::kDLCUDA, ValueError) + << "cuRAND only works on CUDA devices"; int64_t tensor_size = GetTensorSize(tensor); int64_t actual_size = tensor_size % 2 == 0 ? tensor_size : tensor_size + 1; if (tensor->dtype.code == DLDataTypeCode::kDLFloat && tensor->dtype.bits == 16) { @@ -108,7 +108,7 @@ void RandomFill(DLTensor* tensor) { CURandGenerator().Generate64bit(tensor->data, actual_size); } } else { - LOG(FATAL) << "ValueError: Unsupported dtype: " << tensor->dtype; + TVM_FFI_THROW(ValueError) << "Unsupported dtype: " << tensor->dtype; } cuda_api->StreamSync(tensor->device, nullptr); } diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index 972c61e9436e..77dcc2a0b94e 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -65,7 +65,7 @@ inline dnnl::memory::desc GenDNNLMemDescByShape(const dnnl::memory::dims& shape, data_md = dnnl::memory::desc({shape, dtype, tag::abcde}); break; default: - LOG(FATAL) << "Unsupported data shape dimension: " << shape.size(); + TVM_FFI_THROW(InternalError) << "Unsupported data shape dimension: " << shape.size(); break; } return data_md; @@ -332,7 +332,7 @@ extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_ algo = algorithm::binary_mul; break; default: - LOG(FATAL) << "Unsupported dnnl algorithm: " << algo_type; + TVM_FFI_THROW(InternalError) << "Unsupported dnnl algorithm: " << algo_type; break; } diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index aa48fcb19b21..2a2bb7ed867f 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -61,7 +61,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { const char* kind() const override { return "dnnl_json"; } void Init(const ffi::Array& consts) override { - ICHECK_EQ(consts.size(), const_idx_.size()) + TVM_FFI_ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; // Setup constants entries for weights. @@ -70,7 +70,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } /* Unused stub implementation */ - void Run() override { LOG(FATAL) << "Unreachable code"; } + void Run() override { TVM_FFI_THROW(InternalError) << "Unreachable code"; } /* Thread safe implementation of Run. Keep runtime instance immutable */ void Run(const ffi::PackedArgs& args) const { @@ -104,9 +104,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (this->symbol_name_ == name) { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK(this->initialized_) << "The module has not been initialized"; + TVM_FFI_ICHECK(this->initialized_) << "The module has not been initialized"; - ICHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size()) + TVM_FFI_ICHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size()) << "Found mismatch in the number of provided data entries and required."; Run(args); @@ -152,7 +152,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { auto sum_scl_tr = GetInputByName(nid, "sum_scl_idx"); if (o_scl_tr) { - ICHECK(o_scl_tr.IsConstant()); + TVM_FFI_ICHECK(o_scl_tr.IsConstant()); auto data = o_scl_tr.GetConstDataLikeVec(); attr.set_output_scales(data.size() == 1 ? 0 : (1 << 1), data); } @@ -260,7 +260,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { for (size_t nid = 0; nid < nodes_.size(); ++nid) { const auto& node = nodes_[nid]; if (node.GetOpType() == "kernel") { - ICHECK_EQ(node.GetOpType(), "kernel"); + TVM_FFI_ICHECK_EQ(node.GetOpType(), "kernel"); auto op_name = node.GetOpName(); if (tvm::runtime::regex_match(op_name, deconv_pat) || tvm::runtime::regex_match(op_name, conv_transpose_pat)) { @@ -288,7 +288,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } else if ("nn.batch_matmul" == op_name) { BatchMatMul(nid); } else { - LOG(FATAL) << "Unsupported op: " << op_name; + TVM_FFI_THROW(InternalError) << "Unsupported op: " << op_name; } } } @@ -566,7 +566,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { auto center = GetNodeAttr(node, "center"); auto scale = GetNodeAttr(node, "scale"); - ICHECK(axis == 1 && center && scale) << "Unimplemented BatchNorm case"; + TVM_FFI_ICHECK(axis == 1 && center && scale) << "Unimplemented BatchNorm case"; auto bn_desc = dnnl::batch_normalization_forward::desc( dnnl::prop_kind::forward_inference, src_tr.desc(), epsilon, @@ -576,8 +576,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { // Concatenate scale and shift tensors auto scale_shift_tr = TensorRequisite::AsIs(bn_prim_desc.weights_desc(), GenUniqueEid()); auto sc_sh_dims = scale_shift_tr.dims(); - ICHECK(sc_sh_dims.size() == 2); - ICHECK(sc_sh_dims[0] == 2); + TVM_FFI_ICHECK(sc_sh_dims.size() == 2); + TVM_FFI_ICHECK(sc_sh_dims[0] == 2); sc_sh_dims[0] /= 2; auto scale_tr = scale_shift_tr.Crop(sc_sh_dims, {0, 0}).Squeeze(); auto shift_tr = scale_shift_tr.Crop(sc_sh_dims, {1, 0}).Squeeze(); @@ -610,7 +610,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { auto center = GetNodeAttr(node, "center"); auto scale = GetNodeAttr(node, "scale"); - ICHECK(axis == -1 && center && scale) << "Unimplemented LayerNorm case"; + TVM_FFI_ICHECK(axis == -1 && center && scale) << "Unimplemented LayerNorm case"; // LN description. auto lnorm_desc = dnnl::layer_normalization_forward::desc( @@ -623,8 +623,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { auto scale_shift_tr = TensorRequisite::AsIs(lnorm_prim_desc.weights_desc(), GenUniqueEid()); auto sc_sh_dims = scale_shift_tr.dims(); - ICHECK(sc_sh_dims.size() == 2); - ICHECK(sc_sh_dims[0] == 2); + TVM_FFI_ICHECK(sc_sh_dims.size() == 2); + TVM_FFI_ICHECK(sc_sh_dims[0] == 2); sc_sh_dims[0] /= 2; auto scale_tr = scale_shift_tr.Crop(sc_sh_dims, {0, 0}).Squeeze(); auto shift_tr = scale_shift_tr.Crop(sc_sh_dims, {1, 0}).Squeeze(); @@ -661,7 +661,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { src_tr = src_tr.TreatAs(src_layout); dst_tr = dst_tr.TreatAs(dst_layout); - ICHECK(src_tr.dims().size() > 2); + TVM_FFI_ICHECK(src_tr.dims().size() > 2); std::vector feature_size; for (size_t i = 2; i < src_tr.dims().size(); i++) { @@ -723,7 +723,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { auto elt_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, src_tr.desc(), alpha, beta); auto elt_prim_desc = dnnl::eltwise_forward::primitive_desc(elt_desc, engine_); - ICHECK(src_tr.desc() == elt_prim_desc.dst_desc()); + TVM_FFI_ICHECK(src_tr.desc() == elt_prim_desc.dst_desc()); Submit(dnnl::eltwise_forward(elt_prim_desc), {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}}); } @@ -742,7 +742,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { auto softmax_desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_inference, src_tr.desc(), axis); auto softmax_prim_desc = dnnl::softmax_forward::primitive_desc(softmax_desc, engine_); - ICHECK(dst_tr.desc() == softmax_prim_desc.dst_desc()); + TVM_FFI_ICHECK(dst_tr.desc() == softmax_prim_desc.dst_desc()); Submit(dnnl::softmax_forward(softmax_prim_desc), {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}}); @@ -750,7 +750,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { void Binary(const size_t& nid, dnnl::algorithm algo) { auto node = nodes_[nid]; - ICHECK_EQ(node.GetInputs().size(), 2U); + TVM_FFI_ICHECK_EQ(node.GetInputs().size(), 2U); // Memory and compute description. auto lhs_tr = GetInput(nid, 0); @@ -799,7 +799,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { const JSONGraphNode& node = nodes_[nid]; - ICHECK_LT(idx, node.GetInputs().size()); + TVM_FFI_ICHECK_LT(idx, node.GetInputs().size()); auto data_entry = node.GetInputs()[idx]; auto shape_arr = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; @@ -812,8 +812,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { TensorRequisite res; if (const_dl_tensor) { - ICHECK(const_dl_tensor->data); - ICHECK(ffi::IsContiguous(*const_dl_tensor)); + TVM_FFI_ICHECK(const_dl_tensor->data); + TVM_FFI_ICHECK(ffi::IsContiguous(*const_dl_tensor)); auto mem = dnnl::memory(desc, engine_, const_dl_tensor->data); res = TensorRequisite::AsIs(mem, eid); } else { @@ -831,12 +831,12 @@ class DNNLJSONRuntime : public JSONRuntimeBase { if (idx == -1) return {}; // -1 reserved value for empty input. const JSONGraphNode& node = nodes_[nid]; - ICHECK_LT(idx, node.GetNumOutput()); + TVM_FFI_ICHECK_LT(idx, node.GetNumOutput()); auto shape_arr = node.GetOpShape()[idx]; auto dtype = node.GetOpDataType()[idx]; auto eid = node_row_ptr_[nid] + static_cast(idx); - ICHECK(data_entry_[eid] == nullptr); + TVM_FFI_ICHECK(data_entry_[eid] == nullptr); std::vector shape(shape_arr.begin(), shape_arr.end()); auto desc = MakePlainDesc(shape, dtype); diff --git a/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h b/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h index 689113f62865..0385c73df552 100644 --- a/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h +++ b/src/runtime/contrib/dnnl/dnnl_tensor_requisite.h @@ -135,7 +135,7 @@ class TensorRequisite { /*! \brief Make TR with backward dataflow */ TensorRequisite Backward() const { if (!defined()) return *this; - ICHECK(orig_ == nullptr); + TVM_FFI_ICHECK(orig_ == nullptr); return {t_desc_, orig_, reinterpret_, mem_, eid_, true}; } @@ -164,7 +164,7 @@ class TensorRequisite { TensorRequisite Broadcast(const dnnl::memory::dims& shape) const { if (!defined()) return *this; // nothing for empty TR if (t_desc_.dims() == shape) return *this; - ICHECK(!reverse_data_flow_); + TVM_FFI_ICHECK(!reverse_data_flow_); auto orig = std::make_shared(*this); @@ -175,8 +175,8 @@ class TensorRequisite { auto desc = t_desc_.reshape(extended_dims); for (size_t i = 0; i < extended_dims.size(); i++) { if (extended_dims[i] == shape[i]) continue; - ICHECK(extended_dims[i] == 1); - ICHECK(desc.data.dims[i] == desc.data.padded_dims[i]); + TVM_FFI_ICHECK(extended_dims[i] == 1); + TVM_FFI_ICHECK(desc.data.dims[i] == desc.data.padded_dims[i]); desc.data.dims[i] = shape[i]; desc.data.padded_dims[i] = shape[i]; @@ -191,8 +191,8 @@ class TensorRequisite { TensorRequisite Crop(const dnnl::memory::dims& shape, const dnnl::memory::dims& offset) const { if (!defined()) return *this; // nothing for empty TR - ICHECK_EQ(shape.size(), t_desc_.dims().size()); - ICHECK_EQ(offset.size(), t_desc_.dims().size()); + TVM_FFI_ICHECK_EQ(shape.size(), t_desc_.dims().size()); + TVM_FFI_ICHECK_EQ(offset.size(), t_desc_.dims().size()); auto orig = std::make_shared(*this); // reinterpret memory buffer with new strides @@ -220,7 +220,7 @@ class TensorRequisite { } } - ICHECK(desc); + TVM_FFI_ICHECK(desc); return {desc, orig, true, {}, kUndefinedTid, reverse_data_flow_}; } @@ -254,8 +254,8 @@ class TensorRequisite { // If it's the same desc just return self if (desc == t_desc_) return *this; - ICHECK(t_desc_.dims() == desc.dims()) << "Requested layout is not compatible with " - "presented shape"; + TVM_FFI_ICHECK(t_desc_.dims() == desc.dims()) << "Requested layout is not compatible with " + "presented shape"; auto orig = std::make_shared(*this); return {desc, orig, false, {}, kUndefinedTid, reverse_data_flow_}; @@ -273,7 +273,8 @@ class TensorRequisite { if (layout.find("G") != std::string::npos) return "GOI" + sparse_dims[rank - 4]; if (layout.find("O") != std::string::npos) return "OI" + sparse_dims[rank - 3]; - LOG(FATAL) << "Unknown layout " << layout << "There is no default scheme to handle it"; + TVM_FFI_THROW(InternalError) << "Unknown layout " << layout + << "There is no default scheme to handle it"; } /*! @@ -311,13 +312,13 @@ class TensorRequisite { while (it != layout_tokens.end() && it->first == -1) it++; int rank = std::distance(layout_tokens.begin(), it); while (it != layout_tokens.end()) { - ICHECK_NE(it->first, -1) << "DNNL limitation. Blocking dims should be innermost. " - << "But received layout is " << layout; + TVM_FFI_ICHECK_NE(it->first, -1) << "DNNL limitation. Blocking dims should be innermost. " + << "But received layout is " << layout; it++; } - ICHECK_EQ(layout_tokens.size(), origin_dims.size()); - ICHECK_EQ(rank, desired_logic_layout.size()) << layout; + TVM_FFI_ICHECK_EQ(layout_tokens.size(), origin_dims.size()); + TVM_FFI_ICHECK_EQ(rank, desired_logic_layout.size()) << layout; std::vector> outermost_tokens(layout_tokens.begin(), layout_tokens.begin() + rank); @@ -359,7 +360,7 @@ class TensorRequisite { auto tag = p.second; auto dim_size = origin_dims[orig_dim_idx]; auto result_dim_position = dim_position_by_tag[tag]; - ICHECK_EQ(p.first, dim_size) + TVM_FFI_ICHECK_EQ(p.first, dim_size) << "Blocking layout is not applicable to tensor with shape: " << origin_dims << ". Requested layout is " << layout; @@ -428,8 +429,8 @@ class TensorRequisite { std::vector GetConstDataLikeVec() const { auto const_data = GetConstData(); auto desc = const_data.get_desc(); - ICHECK(desc.data_type() == utils::DnnlDType()); - ICHECK(desc.dims().size() == 1); + TVM_FFI_ICHECK(desc.data_type() == utils::DnnlDType()); + TVM_FFI_ICHECK(desc.dims().size() == 1); auto size = desc.get_size() / sizeof(T); auto ptr = static_cast(const_data.get_data_handle()); @@ -440,11 +441,11 @@ class TensorRequisite { /*! \brief Get value of constant scalar tensor if possible. */ template T GetConstScalarData() const { - ICHECK(IsConstant()); - ICHECK(IsScalar()); + TVM_FFI_ICHECK(IsConstant()); + TVM_FFI_ICHECK(IsScalar()); auto const_data = GetConstData(); auto desc = const_data.get_desc(); - ICHECK(desc.data_type() == utils::DnnlDType()); + TVM_FFI_ICHECK(desc.data_type() == utils::DnnlDType()); auto ptr = static_cast(const_data.get_data_handle()); return *ptr; @@ -472,8 +473,8 @@ class TensorRequisite { mem_(const_mem), eid_(eid), reverse_data_flow_(reverse_data_flow) { - if (mem_) ICHECK(!orig_ && !reverse_data_flow_ && eid_ == kUndefinedTid); - if (eid_ != kUndefinedTid) ICHECK(!orig_); + if (mem_) TVM_FFI_ICHECK(!orig_ && !reverse_data_flow_ && eid_ == kUndefinedTid); + if (eid_ != kUndefinedTid) TVM_FFI_ICHECK(!orig_); } /* Descriptor of particular tensor */ @@ -578,7 +579,7 @@ class TensorRegistry { } // 4) Scratchpad - ICHECK(!tr.orig_ && !tr.mem_ && tr.eid_ == TensorRequisite::kUndefinedTid); + TVM_FFI_ICHECK(!tr.orig_ && !tr.mem_ && tr.eid_ == TensorRequisite::kUndefinedTid); auto idx = tmp_mem_collection_.size(); tmp_mem_collection_.push_back(tr.t_desc_); tmp_mem_mapping_[idx] = 0; // zero position tmp mem object is reserved for scratchpads @@ -605,9 +606,9 @@ class TensorRegistry { void MarkInplace(const TensorRequisite& tr, const TensorRequisite& shared) { const auto tr_id = tr.eid(); - ICHECK(tr_id != TensorRequisite::kUndefinedTid); + TVM_FFI_ICHECK(tr_id != TensorRequisite::kUndefinedTid); const auto shared_id = shared.eid(); - ICHECK(shared_id != TensorRequisite::kUndefinedTid); + TVM_FFI_ICHECK(shared_id != TensorRequisite::kUndefinedTid); eid2idx_tmp_[tr_id] = eid2idx_tmp_[shared_id]; } @@ -627,14 +628,14 @@ class TensorRegistry { return MakeArgReq(EXT_EID, idx); } default: - LOG(FATAL) << "Unknown case"; + TVM_FFI_THROW(InternalError) << "Unknown case"; } return {}; } ArgId RegisterReorder(ArgId src_ar, const dnnl::memory::desc& desc, bool reverse_data_flow, ActionQue* action) { - ICHECK(src_ar.flag_ == TMP_STORAGE || src_ar.flag_ == EXT_EID); + TVM_FFI_ICHECK(src_ar.flag_ == TMP_STORAGE || src_ar.flag_ == EXT_EID); auto src_desc = src_ar.flag_ == TMP_STORAGE ? tmp_mem_collection_[src_ar.idx_] : ext_mem_collection_[src_ar.idx_].second; @@ -694,7 +695,7 @@ class TensorRegistry { auto desc = eid_and_desc.second; auto ext_dl_tensor = ext_data_provider_(eid); - ICHECK(ext_dl_tensor->data); + TVM_FFI_ICHECK(ext_dl_tensor->data); return dnnl::memory{desc, eng_, ext_dl_tensor->data}; } } diff --git a/src/runtime/contrib/hipblas/hipblas.cc b/src/runtime/contrib/hipblas/hipblas.cc index b1b264dea72a..b2cc7331117a 100644 --- a/src/runtime/contrib/hipblas/hipblas.cc +++ b/src/runtime/contrib/hipblas/hipblas.cc @@ -130,7 +130,7 @@ void CallHipblasLt(hipblasLtHandle_t hdl, hipStream_t stream, const DLTensor* B, const DLTensor* bias, const DLTensor* C, bool transa, bool transb, void* workspace_ptr, size_t workspace_size, hipblasLtEpilogue_t epilogue) { - ICHECK(TypeEqual(A->dtype, B->dtype)); + TVM_FFI_ICHECK(TypeEqual(A->dtype, B->dtype)); // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed(A) ? !transa : transa; transb = IsInPlaceTransposed(B) ? !transb : transb; @@ -240,7 +240,7 @@ void CallHipblasLt(hipblasLtHandle_t hdl, hipStream_t stream, // hipBLASLt does not seem to support batched GEMM with one of matrices having // one batch (with batch_stride 0). - ICHECK_EQ(batch_count_A, batch_count_B); + TVM_FFI_ICHECK_EQ(batch_count_A, batch_count_B); set_batch(A_desc, batch_count_A, batch_stride_A); set_batch(B_desc, batch_count_B, batch_stride_B); @@ -279,27 +279,27 @@ inline void CallGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t hdl) auto C = args[2].cast(); bool transa = args[3].cast(); bool transb = args[4].cast(); - ICHECK_EQ(A->ndim, 2); - ICHECK_EQ(B->ndim, 2); - ICHECK_EQ(C->ndim, 2); + TVM_FFI_ICHECK_EQ(A->ndim, 2); + TVM_FFI_ICHECK_EQ(B->ndim, 2); + TVM_FFI_ICHECK_EQ(C->ndim, 2); - ICHECK_EQ(ElementStride(A), 1); - ICHECK_EQ(ElementStride(B), 1); - ICHECK_EQ(ElementStride(C), 1); + TVM_FFI_ICHECK_EQ(ElementStride(A), 1); + TVM_FFI_ICHECK_EQ(ElementStride(B), 1); + TVM_FFI_ICHECK_EQ(ElementStride(C), 1); - ICHECK(TypeEqual(A->dtype, B->dtype)); + TVM_FFI_ICHECK(TypeEqual(A->dtype, B->dtype)); // C can never be transposed. - ICHECK(!IsInPlaceTransposed(C)); + TVM_FFI_ICHECK(!IsInPlaceTransposed(C)); // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed(A) ? !transa : transa; transb = IsInPlaceTransposed(B) ? !transb : transb; - ICHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; - ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0) + TVM_FFI_ICHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; + TVM_FFI_ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; - ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0) + TVM_FFI_ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? args[6].cast() : 0.0; @@ -337,28 +337,28 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t auto C = args[2].cast(); bool transa = args[3].cast(); bool transb = args[4].cast(); - ICHECK_EQ(A->ndim, 3); - ICHECK_EQ(B->ndim, 3); - ICHECK_EQ(C->ndim, 3); + TVM_FFI_ICHECK_EQ(A->ndim, 3); + TVM_FFI_ICHECK_EQ(B->ndim, 3); + TVM_FFI_ICHECK_EQ(C->ndim, 3); int batch_size = BatchCount3D(C); - ICHECK_EQ(ElementStride3D(A), 1); - ICHECK_EQ(ElementStride3D(B), 1); - ICHECK_EQ(ElementStride3D(C), 1); + TVM_FFI_ICHECK_EQ(ElementStride3D(A), 1); + TVM_FFI_ICHECK_EQ(ElementStride3D(B), 1); + TVM_FFI_ICHECK_EQ(ElementStride3D(C), 1); - ICHECK(TypeEqual(A->dtype, B->dtype)); + TVM_FFI_ICHECK(TypeEqual(A->dtype, B->dtype)); // C can never be transposed. - ICHECK(!IsInPlaceTransposed3D(C)); + TVM_FFI_ICHECK(!IsInPlaceTransposed3D(C)); // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed3D(A) ? !transa : transa; transb = IsInPlaceTransposed3D(B) ? !transb : transb; - ICHECK(CheckMixPrecisionType(A->dtype, C->dtype, true)) << "Unsupported data type"; - ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride3D(A) % 4 == 0) + TVM_FFI_ICHECK(CheckMixPrecisionType(A->dtype, C->dtype, true)) << "Unsupported data type"; + TVM_FFI_ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride3D(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; - ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride3D(B) % 4 == 0) + TVM_FFI_ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride3D(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; double alpha = args.size() > 5 ? args[5].cast() : 1.0; double beta = args.size() > 6 ? args[6].cast() : 0.0; @@ -377,8 +377,8 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t B_stride = 0; } } else { - ICHECK_EQ(batch_size_a, batch_size); - ICHECK_EQ(batch_size_b, batch_size); + TVM_FFI_ICHECK_EQ(batch_size_a, batch_size); + TVM_FFI_ICHECK_EQ(batch_size_b, batch_size); } hipblasDatatype_t hip_in_type = GetHipBlasDataType(A->dtype); @@ -419,9 +419,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(A->device); if (TypeEqual(A->dtype, C->dtype)) { - ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || - TypeMatch(A->dtype, kDLFloat, 32) || - TypeMatch(A->dtype, kDLFloat, 64)); + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || + TypeMatch(A->dtype, kDLFloat, 32) || + TypeMatch(A->dtype, kDLFloat, 64)); if (TypeMatch(A->dtype, kDLFloat, 16)) { CallGemm(args, ret, HipblasHgemmOp(entry_ptr->handle)); @@ -441,8 +441,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(A->device); if (TypeEqual(A->dtype, C->dtype)) { - ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || - TypeMatch(A->dtype, kDLFloat, 64)); + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || + TypeMatch(A->dtype, kDLFloat, 64)); if (TypeMatch(A->dtype, kDLFloat, 16)) { CallBatchGemm(args, ret, HipblasHgemmBatchOp(entry_ptr->handle)); diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc index 45bfabc277cc..d0abfd61e382 100644 --- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc +++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc @@ -56,7 +56,7 @@ class HipblasJSONRuntime : public JSONRuntimeBase { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); if (this->symbol_name_ == name) { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK(this->initialized_) << "The module has not been initialized"; + TVM_FFI_ICHECK(this->initialized_) << "The module has not been initialized"; this->Run(args); }); } else { @@ -92,9 +92,9 @@ class HipblasJSONRuntime : public JSONRuntimeBase { hipStream_t stream = static_cast(TVMFFIEnvGetStream(kDLROCM, device_id)); auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) { - ICHECK_LT(idx, node.GetInputs().size()); + TVM_FFI_ICHECK_LT(idx, node.GetInputs().size()); auto eid = EntryID(node.GetInputs()[idx]); - ICHECK(eid < dl_tensors.size()); + TVM_FFI_ICHECK(eid < dl_tensors.size()); return dl_tensors[eid]; }; @@ -137,7 +137,7 @@ class HipblasJSONRuntime : public JSONRuntimeBase { } } - void Run() override { LOG(FATAL) << "Unreachable"; } + void Run() override { TVM_FFI_THROW(InternalError) << "Unreachable"; } }; ffi::Module HipblasJSONRuntimeCreate(ffi::String symbol_name, ffi::String graph_json, diff --git a/src/runtime/contrib/hipblas/hipblas_utils.h b/src/runtime/contrib/hipblas/hipblas_utils.h index d07e825c21c8..ed4001ebf705 100644 --- a/src/runtime/contrib/hipblas/hipblas_utils.h +++ b/src/runtime/contrib/hipblas/hipblas_utils.h @@ -57,10 +57,11 @@ inline const char* GetHipblasErrorString(int error) { } #ifndef CHECK_HIPBLAS_ERROR -#define CHECK_HIPBLAS_ERROR(fn) \ - do { \ - int error = static_cast(fn); \ - ICHECK_EQ(error, HIPBLAS_STATUS_SUCCESS) << "HIPBLAS: " << GetHipblasErrorString(error); \ +#define CHECK_HIPBLAS_ERROR(fn) \ + do { \ + int error = static_cast(fn); \ + TVM_FFI_ICHECK_EQ(error, HIPBLAS_STATUS_SUCCESS) \ + << "HIPBLAS: " << GetHipblasErrorString(error); \ } while (0) // ; intentionally left off. #endif // CHECK_HIPBLAS_ERROR @@ -110,7 +111,7 @@ inline hipDataType GetHipDataType(DLDataType type) { return HIP_R_64F; } } - LOG(FATAL) << "Unsupported hip type"; + TVM_FFI_THROW(InternalError) << "Unsupported hip type"; } inline hipblasDatatype_t GetHipBlasDataType(DLDataType type) { @@ -138,7 +139,7 @@ inline hipblasDatatype_t GetHipBlasDataType(DLDataType type) { return HIPBLAS_R_64F; } } - LOG(FATAL) << "Unsupported hip type"; + TVM_FFI_THROW(InternalError) << "Unsupported hip type"; } /*! \brief Execute matrix multiply followed by the specified epilogue, using hipBLASLt. */ diff --git a/src/runtime/contrib/json/json_node.h b/src/runtime/contrib/json/json_node.h index 67ef3dc5b54b..c165f6b05cf3 100644 --- a/src/runtime/contrib/json/json_node.h +++ b/src/runtime/contrib/json/json_node.h @@ -71,7 +71,7 @@ class JSONGraphNodeEntry { * \brief Deserialize the json array into a node entry. */ void Load(ffi::json::Array arr) { - ICHECK_GE(arr.size(), 2) << "invalid json format"; + TVM_FFI_ICHECK_GE(arr.size(), 2) << "invalid json format"; id_ = static_cast(arr[0].cast()); index_ = static_cast(arr[1].cast()); if (arr.size() > 2) { @@ -157,7 +157,7 @@ class JSONGraphNode { dtype_ = GetAttr>("dtype"); } if (shape_.defined() && dtype_.defined()) { - ICHECK_EQ(shape_.size(), dtype_.size()); + TVM_FFI_ICHECK_EQ(shape_.size(), dtype_.size()); } } @@ -182,7 +182,7 @@ class JSONGraphNode { } else if (key == "attr" || key == "attrs") { this->LoadAttrs(kv.second.cast()); } else { - LOG(FATAL) << "Unknown key: " << key; + TVM_FFI_THROW(InternalError) << "Unknown key: " << key; } } } @@ -253,7 +253,7 @@ class JSONGraphNode { */ template T GetAttr(const std::string& key) const { - ICHECK(attrs_.count(key) > 0) << "Key: " << key << " is not found"; + TVM_FFI_ICHECK(attrs_.count(key) > 0) << "Key: " << key << " is not found"; return attrs_[key].cast(); } diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h index ef41144fb914..d00d03ec89cc 100644 --- a/src/runtime/contrib/json/json_runtime.h +++ b/src/runtime/contrib/json/json_runtime.h @@ -80,7 +80,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { * \param pointer to profiler */ virtual void RunProfile(profiling::Profiler* prof) { - LOG(FATAL) << "Not expected to be here : Profiling call w/o support ?"; + TVM_FFI_THROW(InternalError) << "Not expected to be here : Profiling call w/o support ?"; } /*! @@ -88,7 +88,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { * \return External compiler specific debug blob */ virtual std::string DebugDump(void) { - LOG(FATAL) << "Not expected to be here : Debug dump w/o support ?"; + TVM_FFI_THROW(InternalError) << "Not expected to be here : Debug dump w/o support ?"; return ""; } @@ -108,7 +108,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { [sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->const_names_; }); } else if (this->symbol_name_ == name) { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK(this->initialized_) << "The module has not been initialized"; + TVM_FFI_ICHECK(this->initialized_) << "The module has not been initialized"; // Bind argument tensors to data entries. this->SetInputOutputBuffers(args); @@ -123,7 +123,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { return ffi::Function(nullptr); } return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK(this->initialized_) << "The module has not been initialized"; + TVM_FFI_ICHECK(this->initialized_) << "The module has not been initialized"; // Bind argument tensors to data entries. this->SetInputOutputBuffers(args); @@ -143,7 +143,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { } else if ("__init_" + this->symbol_name_ == name) { // The function to initialize constant tensors. return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK_EQ(args.size(), 1U); + TVM_FFI_ICHECK_EQ(args.size(), 1U); std::lock_guard guard(this->initialize_mutex_); if (!this->initialized_) { this->Init(args[0].cast>()); @@ -180,9 +180,9 @@ class JSONRuntimeBase : public ffi::ModuleObj { std::string graph_json; std::vector consts; // Load the symbol - ICHECK(stream.Read(&symbol)) << "Loading symbol name failed"; - ICHECK(stream.Read(&graph_json)) << "Loading graph json failed"; - ICHECK(stream.Read(&consts)) << "Loading the const name list failed"; + TVM_FFI_ICHECK(stream.Read(&symbol)) << "Loading symbol name failed"; + TVM_FFI_ICHECK(stream.Read(&graph_json)) << "Loading graph json failed"; + TVM_FFI_ICHECK(stream.Read(&consts)) << "Loading the const name list failed"; ffi::Array const_names; for (const auto& it : consts) { const_names.push_back(it); @@ -207,7 +207,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { * \param args The packed args. */ void SetInputOutputBuffers(const ffi::PackedArgs& args) { - ICHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size()) + TVM_FFI_ICHECK_EQ(args.size(), input_var_eid_.size() + outputs_.size()) << "Found mismatch in the number of provided data entryies and required."; for (size_t i = 0; i < static_cast(args.size()); i++) { @@ -242,24 +242,24 @@ class JSONRuntimeBase : public ffi::ModuleObj { uint32_t nid = input_nodes_[i]; std::string name = nodes_[nid].name_; if (nodes_[nid].op_type_ == "input") { - ICHECK_EQ(nodes_[nid].GetOpShape().size(), nodes_[nid].GetOpDataType().size()); + TVM_FFI_ICHECK_EQ(nodes_[nid].GetOpShape().size(), nodes_[nid].GetOpDataType().size()); for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) { input_var_eid_.push_back(EntryID(nid, j)); } nodes_[nid].SetNumOutput(nodes_[nid].GetOpShape().size()); } else { - ICHECK_EQ(nodes_[nid].op_type_, "const"); + TVM_FFI_ICHECK_EQ(nodes_[nid].op_type_, "const"); auto pos = std::find(std::begin(const_names_), std::end(const_names_), name); - ICHECK(pos != std::end(const_names_)) << "Found non-existent constant: " << name; + TVM_FFI_ICHECK(pos != std::end(const_names_)) << "Found non-existent constant: " << name; const_idx_.push_back(nid); consts.push_back(name); } } - ICHECK_EQ(consts.size(), const_names_.size()) + TVM_FFI_ICHECK_EQ(consts.size(), const_names_.size()) << "Found mismatch for the number of constants in the graph and required."; for (size_t i = 0; i < consts.size(); i++) { - ICHECK_EQ(consts[i], const_names_[i]) + TVM_FFI_ICHECK_EQ(consts[i], const_names_[i]) << "The position of constant in the graph must be the same as the required."; } @@ -315,7 +315,7 @@ class JSONRuntimeBase : public ffi::ModuleObj { } else if (key == "symbol") { // ignored } else { - LOG(FATAL) << "Unknown key: " << key; + TVM_FFI_THROW(InternalError) << "Unknown key: " << key; } } } diff --git a/src/runtime/contrib/miopen/miopen_utils.h b/src/runtime/contrib/miopen/miopen_utils.h index 76913696b0b9..b6cee1a261b5 100644 --- a/src/runtime/contrib/miopen/miopen_utils.h +++ b/src/runtime/contrib/miopen/miopen_utils.h @@ -38,10 +38,10 @@ namespace miopen { std::string miopenGetErrorString(int error_code); -#define MIOPEN_CALL(func) \ - { \ - miopenStatus_t e = (func); \ - ICHECK_EQ(e, miopenStatusSuccess) << "miopen error: " << miopenGetErrorString(e); \ +#define MIOPEN_CALL(func) \ + { \ + miopenStatus_t e = (func); \ + TVM_FFI_ICHECK_EQ(e, miopenStatusSuccess) << "miopen error: " << miopenGetErrorString(e); \ } struct ConvEntry { diff --git a/src/runtime/contrib/miopen/softmax.cc b/src/runtime/contrib/miopen/softmax.cc index c5e467626ee8..455ad571e5b0 100644 --- a/src/runtime/contrib/miopen/softmax.cc +++ b/src/runtime/contrib/miopen/softmax.cc @@ -40,10 +40,10 @@ void softmax_impl(ffi::PackedArgs args, ffi::Any* ret, miopenSoftmaxAlgorithm_t int ndim = x->ndim; int64_t* shape = x->shape; if (axis < 0) axis += ndim; - ICHECK(axis >= 0 && axis < ndim); + TVM_FFI_ICHECK(axis >= 0 && axis < ndim); // just fp32 for now - ICHECK(TypeMatch(x->dtype, kDLFloat, 32)); - ICHECK(TypeMatch(y->dtype, kDLFloat, 32)); + TVM_FFI_ICHECK(TypeMatch(x->dtype, kDLFloat, 32)); + TVM_FFI_ICHECK(TypeMatch(y->dtype, kDLFloat, 32)); MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(x->device); diff --git a/src/runtime/contrib/mrvl/mrvl_base64.h b/src/runtime/contrib/mrvl/mrvl_base64.h index 67452597fd48..b99d8d171bc7 100644 --- a/src/runtime/contrib/mrvl/mrvl_base64.h +++ b/src/runtime/contrib/mrvl/mrvl_base64.h @@ -38,7 +38,7 @@ namespace contrib { namespace mrvl { inline size_t b64strlen(const std::string& b64str) { - ICHECK(b64str.size() % 4 == 0) << "invalid base64 encoding"; + TVM_FFI_ICHECK(b64str.size() % 4 == 0) << "invalid base64 encoding"; size_t length = b64str.size() / 4 * 3; if (b64str[b64str.size() - 2] == '=') { length -= 2; @@ -67,7 +67,7 @@ inline void b64decode(const std::string& b64str, uint8_t* ret) { } } } - ICHECK(b64strlen(b64str) == index) << "base64 decoding fails"; + TVM_FFI_ICHECK(b64strlen(b64str) == index) << "base64 decoding fails"; } } // namespace mrvl diff --git a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc index 26a603c430ee..f633c88591ad 100644 --- a/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_hw_runtime.cc @@ -127,7 +127,7 @@ struct ml_dpdk_cb { void get_tvmc_callbacks(const char* so_path, ml_tvmc_cb* obj) { obj->handle = dlopen(so_path, RTLD_LAZY); if (obj->handle == nullptr) - ICHECK(false) << "Marvell-Runtime-ERROR Loading shared library failed"; + TVM_FFI_ICHECK(false) << "Marvell-Runtime-ERROR Loading shared library failed"; obj->mrvl_tvmc_ml_init = (mrvl_tvmc_ml_init_ptr)dlsym(obj->handle, "mrvl_ml_init"); obj->mrvl_tvmc_ml_finish = (mrvl_tvmc_ml_finish_ptr)dlsym(obj->handle, "mrvl_ml_finish"); @@ -176,12 +176,12 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { // Deallocate input quantize and output quantize buffer ret = dpdk_cb_.mrvl_dpdk_io_free(device_handle, run_arg.model_id, symbol_name_.c_str()); - ICHECK(ret == 0) << "IO free failed, model_id =" << run_arg.model_id; + TVM_FFI_ICHECK(ret == 0) << "IO free failed, model_id =" << run_arg.model_id; // Unload model ret = dpdk_cb_.mrvl_dpdk_glow_layer_unload(run_arg.device, run_arg.model_id, symbol_name_.c_str()); - ICHECK(ret == 0) << "Model layer unload failed, model_id =" << run_arg.model_id; + TVM_FFI_ICHECK(ret == 0) << "Model layer unload failed, model_id =" << run_arg.model_id; num_loaded--; } else { // Clean Up @@ -268,12 +268,12 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { int num_inputs, num_outputs, batch_size; // Load the symbol_name and other data to construct the module - ICHECK(stream.Read(&symbol_name)) << "Loading symbol name failed"; - ICHECK(stream.Read(&nodes_json)) << "Loading nodes json failed"; - ICHECK(stream.Read(&bin_code)) << "Loading binary code failed"; - ICHECK(stream.Read(&num_inputs)) << "Loading num_inputs failed"; - ICHECK(stream.Read(&num_outputs)) << "Loading num_outputs failed"; - ICHECK(stream.Read(&batch_size)) << "Loading batch_size failed"; + TVM_FFI_ICHECK(stream.Read(&symbol_name)) << "Loading symbol name failed"; + TVM_FFI_ICHECK(stream.Read(&nodes_json)) << "Loading nodes json failed"; + TVM_FFI_ICHECK(stream.Read(&bin_code)) << "Loading binary code failed"; + TVM_FFI_ICHECK(stream.Read(&num_inputs)) << "Loading num_inputs failed"; + TVM_FFI_ICHECK(stream.Read(&num_outputs)) << "Loading num_outputs failed"; + TVM_FFI_ICHECK(stream.Read(&batch_size)) << "Loading batch_size failed"; auto n = ffi::make_object(symbol_name, nodes_json, bin_code, num_inputs, num_outputs, batch_size); return ffi::Module(n); @@ -445,19 +445,19 @@ class MarvellHardwareModuleNode : public ffi::ModuleObj { ret = dpdk_cb_.mrvl_dpdk_glow_layer_load( run_arg.device, run_arg.model_id, symbol_name_.c_str(), reinterpret_cast(byte_array.data()), num_bytes, &run_arg.layer_idx); - ICHECK(ret == 0) << "Model layer load failed, model_id =" << run_arg.model_id; + TVM_FFI_ICHECK(ret == 0) << "Model layer load failed, model_id =" << run_arg.model_id; num_loaded++; // Allocate input quantize and output quantize buffer ret = dpdk_cb_.mrvl_dpdk_io_alloc(device_handle, run_arg.model_id, symbol_name_.c_str(), reinterpret_cast(&run_arg.i_q_buf), reinterpret_cast(&run_arg.o_q_buf)); - ICHECK(ret == 0) << "IO alloc failed, model_id =" << run_arg.model_id; + TVM_FFI_ICHECK(ret == 0) << "IO alloc failed, model_id =" << run_arg.model_id; } else { // Load the model run_arg.model_id = tvmc_cb_.mrvl_tvmc_ml_model_load(reinterpret_cast(byte_array.data()), num_bytes); - ICHECK(run_arg.model_id >= 0) << "Failed to load model!"; + TVM_FFI_ICHECK(run_arg.model_id >= 0) << "Failed to load model!"; num_loaded++; // Allocate input quantize and dequant buffer run_arg.i_q_buf = tvmc_cb_.mrvl_tvmc_ml_io_alloc(run_arg.model_id, input_quantize, nullptr); diff --git a/src/runtime/contrib/mrvl/mrvl_runtime.cc b/src/runtime/contrib/mrvl/mrvl_runtime.cc index bcc4bcc2e291..13673fe469f2 100644 --- a/src/runtime/contrib/mrvl/mrvl_runtime.cc +++ b/src/runtime/contrib/mrvl/mrvl_runtime.cc @@ -106,11 +106,12 @@ class MarvellSimulatorModuleNode : public ffi::ModuleObj { std::string nodes_json; std::string bin_code; // Load the symbol_name and other data to construct the module - ICHECK(stream.Read(&symbol_name)) + TVM_FFI_ICHECK(stream.Read(&symbol_name)) << "Marvell-Compiler-ERROR-Internal::Loading symbol name failed"; - ICHECK(stream.Read(&nodes_json)) + TVM_FFI_ICHECK(stream.Read(&nodes_json)) << "Marvell-Compiler-ERROR-Internal::Loading nodes json failed"; - ICHECK(stream.Read(&bin_code)) << "Marvell-Compiler-ERROR-Internal::Loading bin code failed"; + TVM_FFI_ICHECK(stream.Read(&bin_code)) + << "Marvell-Compiler-ERROR-Internal::Loading bin code failed"; auto n = ffi::make_object(symbol_name, nodes_json, bin_code); return ffi::Module(n); } @@ -131,7 +132,7 @@ class MarvellSimulatorModuleNode : public ffi::ModuleObj { size_t num_outputs_; void Run(ffi::PackedArgs args) { - ICHECK_EQ(args.size(), num_inputs_ + num_outputs_) + TVM_FFI_ICHECK_EQ(args.size(), num_inputs_ + num_outputs_) << "Marvell-Compiler-ERROR-Internal::Mismatch in number of input & number of output args " "to subgraph"; tvm::runtime::contrib::mrvl::RunMarvellSimulator(args, symbol_name_, bin_code_, num_inputs_, diff --git a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc index a7d50f412c9d..5aa775f314a4 100644 --- a/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc +++ b/src/runtime/contrib/mrvl/mrvl_sw_runtime_lib.cc @@ -128,11 +128,11 @@ static void ReadOutputsAndUpdateRuntime(ffi::PackedArgs args, size_t num_inputs, float* data = new float[tot_dim](); ffi::String outbin = out_bin_prefix + "-" + std::to_string(out - num_inputs) + ".bin"; std::ifstream fin(outbin, std::ios::binary); - ICHECK(fin.is_open()) << "Cannot open file: " << outbin; + TVM_FFI_ICHECK(fin.is_open()) << "Cannot open file: " << outbin; int i = 0; while (fin.read(reinterpret_cast(&f), sizeof(float))) { data[i] = f; - ICHECK(i < tot_dim) << "Output data size mismatch"; + TVM_FFI_ICHECK(i < tot_dim) << "Output data size mismatch"; i++; } arr.CopyFromBytes(data, tot_dim * sizeof(float)); @@ -157,8 +157,8 @@ void tvm::runtime::contrib::mrvl::RunMarvellSimulator(ffi::PackedArgs args, const auto search_path = tvm::ffi::Function::GetGlobal("tvm.mrvl.SearchPath"); std::string tools_directory = (*search_path)(file_name); if (tools_directory.empty()) { - ICHECK(false) << "mrvl-mlsim simulator not found! Please specify the path to Marvell " - "tools by adding it to $PATH."; + TVM_FFI_ICHECK(false) << "mrvl-mlsim simulator not found! Please specify the path to Marvell " + "tools by adding it to $PATH."; } const auto temp_dir = tvm::ffi::Function::GetGlobal("tvm.mrvl.TempDir"); diff --git a/src/runtime/contrib/msc/tensorrt_runtime.cc b/src/runtime/contrib/msc/tensorrt_runtime.cc index 649da3f9fd6a..e44ed1df0390 100644 --- a/src/runtime/contrib/msc/tensorrt_runtime.cc +++ b/src/runtime/contrib/msc/tensorrt_runtime.cc @@ -87,7 +87,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { * \param consts The constant params from compiled model. */ void Init(const ffi::Array& consts) override { - ICHECK_EQ(consts.size(), const_idx_.size()) + TVM_FFI_ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; LoadGlobalOptions(); for (size_t nid = 0; nid < nodes_.size(); nid++) { @@ -120,7 +120,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { SetInputOutputBinds(); if (tool_tag_.size() > 0) { const auto pf = tvm::ffi::Function::GetGlobal("msc_tool.callback_step"); - ICHECK(pf.has_value()) << "Cannot find msc_tool.callback_step func."; + TVM_FFI_ICHECK(pf.has_value()) << "Cannot find msc_tool.callback_step func."; ffi::Map input_datas; int device_id = 0; for (const auto& pair : input_bindings_) { @@ -134,7 +134,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { } auto tvm_stream = TVMFFIEnvGetStream(kDLCUDA, device_id); #if TRT_VERSION_GE(6, 0, 1) - ICHECK(context_->enqueueV2(bindings_.data(), tvm_stream, nullptr)) + TVM_FFI_ICHECK(context_->enqueueV2(bindings_.data(), tvm_stream, nullptr)) << "Running TensorRT failed."; #else LOG_FATAL << "Only support tensorrt with version >=6.0.0"; @@ -145,7 +145,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { uint32_t eid = EntryID(outputs_[i]); const auto& name = nodes_[nid].GetOpName() + ":" + std::to_string(outputs_[i].index_); int binding_index = engine_->getBindingIndex(name.c_str()); - ICHECK_NE(binding_index, -1); + TVM_FFI_ICHECK_NE(binding_index, -1); if (data_entry_[eid]->device.device_type != kDLCUDA || tool_tag_.size() > 0) { auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); device_buffer.CopyTo(const_cast(data_entry_[eid])); @@ -153,7 +153,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { } if (tool_tag_.size() > 0) { const auto pf = tvm::ffi::Function::GetGlobal("msc_tool.callback_step"); - ICHECK(pf.has_value()) << "Cannot find msc_tool.callback_step func."; + TVM_FFI_ICHECK(pf.has_value()) << "Cannot find msc_tool.callback_step func."; ffi::Map output_datas; for (int bid = 0; bid < engine_->getNbBindings(); bid++) { if (input_bindings_.count(bid)) { @@ -239,11 +239,11 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { uint32_t eid = EntryID(nid, j); const auto& name = nodes_[nid].GetOpName() + ":" + std::to_string(j); int binding_index = engine_->getBindingIndex(name.c_str()); - ICHECK_NE(binding_index, -1); + TVM_FFI_ICHECK_NE(binding_index, -1); #if TRT_VERSION_GE(6, 0, 1) std::vector shape(data_entry_[eid]->shape, data_entry_[eid]->shape + data_entry_[eid]->ndim); - ICHECK(context_->setBindingDimensions(binding_index, VectorToTrtDims(shape))); + TVM_FFI_ICHECK(context_->setBindingDimensions(binding_index, VectorToTrtDims(shape))); #endif if (data_entry_[eid]->device.device_type == kDLCUDA && tool_tag_.size() == 0) { bindings_[binding_index] = data_entry_[eid]->data; @@ -267,7 +267,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { uint32_t eid = EntryID(outputs_[i]); const auto& name = nodes_[nid].GetOpName() + ":" + std::to_string(outputs_[i].index_); int binding_index = engine_->getBindingIndex(name.c_str()); - ICHECK_NE(binding_index, -1); + TVM_FFI_ICHECK_NE(binding_index, -1); if (data_entry_[eid]->device.device_type == kDLCUDA && tool_tag_.size() == 0) { bindings_[binding_index] = data_entry_[eid]->data; } else { @@ -284,7 +284,8 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { } if (!device_buffers_.count(bid)) { const auto& tensor_name = engine_->getBindingName(bid); - ICHECK(tensor_ids_.count(tensor_name)) << "Can not find tensor_name " << tensor_name; + TVM_FFI_ICHECK(tensor_ids_.count(tensor_name)) + << "Can not find tensor_name " << tensor_name; const auto& pair = tensor_ids_[tensor_name]; auto shape = nodes_[pair.first].GetOpShape()[pair.second]; auto dtype = nodes_[pair.first].GetOpDataType()[pair.second]; @@ -318,8 +319,8 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { #else // TVM_GRAPH_EXECUTOR_TENSORRT void Run() override { - LOG(FATAL) << "TensorRT runtime is not enabled. " - << "Please build with USE_TENSORRT_RUNTIME."; + TVM_FFI_THROW(InternalError) << "TensorRT runtime is not enabled. " + << "Please build with USE_TENSORRT_RUNTIME."; } bool LoadEngine(const ffi::String& engine_file) { return false; } diff --git a/src/runtime/contrib/nnapi/nnapi_builder.cc b/src/runtime/contrib/nnapi/nnapi_builder.cc index d43f00661de9..044ff1ccd4a8 100644 --- a/src/runtime/contrib/nnapi/nnapi_builder.cc +++ b/src/runtime/contrib/nnapi/nnapi_builder.cc @@ -131,7 +131,7 @@ bool NNAPIOperand::IsDynamicShape() const { } NNAPIModelBuilder::NNAPIModelBuilder() { - ICHECK_EQ(ANeuralNetworksModel_create(&model_), ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksModel_create(&model_), ANEURALNETWORKS_NO_ERROR); } NNAPIModelBuilder::~NNAPIModelBuilder() { ANeuralNetworksModel_free(model_); } @@ -140,11 +140,11 @@ NNAPIOperand NNAPIModelBuilder::CreateOperandWithValue(const DLTensor& tensor) { NNAPIOperand operand(next_operand_index_++, &tensor); const size_t operand_data_size = GetDataSize(tensor); - ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), - ANEURALNETWORKS_NO_ERROR); - ICHECK_EQ(ANeuralNetworksModel_setOperandValue(model_, operand.GetOperandIndex(), tensor.data, - operand_data_size), - ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), + ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksModel_setOperandValue(model_, operand.GetOperandIndex(), + tensor.data, operand_data_size), + ANEURALNETWORKS_NO_ERROR); return operand; } @@ -154,10 +154,11 @@ NNAPIOperand NNAPIModelBuilder::CreateOperandWithValue(int32_t tensor_type, int32_t zero_point, const void* buffer, size_t size) { NNAPIOperand operand(next_operand_index_++, tensor_type, dimensions, scale, zero_point); - ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), - ANEURALNETWORKS_NO_ERROR); - ICHECK_EQ(ANeuralNetworksModel_setOperandValue(model_, operand.GetOperandIndex(), buffer, size), - ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), + ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ( + ANeuralNetworksModel_setOperandValue(model_, operand.GetOperandIndex(), buffer, size), + ANEURALNETWORKS_NO_ERROR); return operand; } @@ -165,62 +166,64 @@ NNAPIOperand NNAPIModelBuilder::CreateScalarOperandWithValue(OperandCode operand const void* buffer, size_t size) { NNAPIOperand operand = NNAPIOperand::Scalar(next_operand_index_++, operand_code, {}, 0.0f, 0); - ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), - ANEURALNETWORKS_NO_ERROR); - ICHECK_EQ(ANeuralNetworksModel_setOperandValue(model_, operand.GetOperandIndex(), buffer, size), - ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), + ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ( + ANeuralNetworksModel_setOperandValue(model_, operand.GetOperandIndex(), buffer, size), + ANEURALNETWORKS_NO_ERROR); return operand; } NNAPIOperand NNAPIModelBuilder::CreateOperand(const DLTensor& tensor) { NNAPIOperand operand(next_operand_index_++, tensor.shape, tensor.ndim, tensor.dtype); - ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), - ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), + ANEURALNETWORKS_NO_ERROR); return operand; } NNAPIOperand NNAPIModelBuilder::CreateOperand(const int64_t* shape, int ndim, DLDataType dtype) { NNAPIOperand operand(next_operand_index_++, shape, ndim, dtype); - ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), - ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), + ANEURALNETWORKS_NO_ERROR); return operand; } NNAPIOperand NNAPIModelBuilder::CreateOperand(int32_t tensor_type, std::vector dimensions, float scale, int32_t zero_point) { NNAPIOperand operand(next_operand_index_++, tensor_type, dimensions, scale, zero_point); - ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), - ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksModel_addOperand(model_, operand.GetOperandType().Get()), + ANEURALNETWORKS_NO_ERROR); return operand; } void NNAPIModelBuilder::AddOperation(ANeuralNetworksOperationType operation, const std::vector input_indicies, const std::vector output_indicies) { - ICHECK_EQ(ANeuralNetworksModel_addOperation(model_, operation, input_indicies.size(), - input_indicies.data(), output_indicies.size(), - output_indicies.data()), - ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksModel_addOperation(model_, operation, input_indicies.size(), + input_indicies.data(), output_indicies.size(), + output_indicies.data()), + ANEURALNETWORKS_NO_ERROR); } void NNAPIModelBuilder::Finish(const std::vector& model_input_operands, const std::vector& model_output_operands) { const auto model_input_indices = ExtractOperandIndices(model_input_operands); const auto model_output_indices = ExtractOperandIndices(model_output_operands); - ICHECK_EQ(ANeuralNetworksModel_identifyInputsAndOutputs( - model_, model_input_indices.size(), model_input_indices.data(), - model_output_indices.size(), model_output_indices.data()), - ANEURALNETWORKS_NO_ERROR); - ICHECK_EQ(ANeuralNetworksModel_finish(model_), ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksModel_identifyInputsAndOutputs( + model_, model_input_indices.size(), model_input_indices.data(), + model_output_indices.size(), model_output_indices.data()), + ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksModel_finish(model_), ANEURALNETWORKS_NO_ERROR); } ANeuralNetworksCompilation* NNAPIModelBuilder::Compile() { ANeuralNetworksCompilation* compilation; - ICHECK_EQ(ANeuralNetworksCompilation_create(model_, &compilation), ANEURALNETWORKS_NO_ERROR); - ICHECK_EQ(ANeuralNetworksCompilation_setPreference(compilation, - ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER), - ANEURALNETWORKS_NO_ERROR); - ICHECK_EQ(ANeuralNetworksCompilation_finish(compilation), ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksCompilation_create(model_, &compilation), + ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksCompilation_setPreference( + compilation, ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER), + ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksCompilation_finish(compilation), ANEURALNETWORKS_NO_ERROR); return compilation; } @@ -229,13 +232,14 @@ int32_t TensorTypeFromDLDataType(DLDataType ty) { if (ty.bits == 32) { return ANEURALNETWORKS_TENSOR_INT32; } else { - ICHECK(false) << "Unsupported bit width " << ty.bits << " for NNAPI integer tensor"; + TVM_FFI_ICHECK(false) << "Unsupported bit width " << ty.bits << " for NNAPI integer tensor"; } } else if (ty.code == kDLUInt) { if (ty.bits == 1) { return ANEURALNETWORKS_TENSOR_BOOL8; } else { - ICHECK(false) << "Unsupported bit width " << ty.bits << " for NNAPI unsigned integer tensor"; + TVM_FFI_ICHECK(false) << "Unsupported bit width " << ty.bits + << " for NNAPI unsigned integer tensor"; } } else if (ty.code == kDLFloat) { if (ty.bits == 32) { @@ -243,10 +247,10 @@ int32_t TensorTypeFromDLDataType(DLDataType ty) { } else if (ty.bits == 16) { return ANEURALNETWORKS_TENSOR_FLOAT16; } else { - ICHECK(false) << "Unsupported bit width " << ty.bits << " for NNAPI integer tensor"; + TVM_FFI_ICHECK(false) << "Unsupported bit width " << ty.bits << " for NNAPI integer tensor"; } } else { - ICHECK(false) << "Unsupported DLDataTypeCode for NNAPI: " << ty.code; + TVM_FFI_ICHECK(false) << "Unsupported DLDataTypeCode for NNAPI: " << ty.code; } } diff --git a/src/runtime/contrib/nnapi/nnapi_ops.cc b/src/runtime/contrib/nnapi/nnapi_ops.cc index e016572acda2..a6b5a9c221a7 100644 --- a/src/runtime/contrib/nnapi/nnapi_ops.cc +++ b/src/runtime/contrib/nnapi/nnapi_ops.cc @@ -62,12 +62,12 @@ void ElwBinaryOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNo }; auto it = op_map.find(op_name_); - ICHECK(it != op_map.end()) << "Unsupported binary operation type " << op_name_; + TVM_FFI_ICHECK(it != op_map.end()) << "Unsupported binary operation type " << op_name_; const ANeuralNetworksOperationType operation_type = std::get<0>(it->second); const bool requires_fuse_code = std::get<1>(it->second); - ICHECK_EQ(inputs.size(), 2) << "Expected binary operation to have 2 inputs but got " - << inputs.size(); + TVM_FFI_ICHECK_EQ(inputs.size(), 2) + << "Expected binary operation to have 2 inputs but got " << inputs.size(); auto input_indices = ExtractOperandIndices(inputs); const auto output_indices = ExtractOperandIndices(outputs); @@ -101,7 +101,7 @@ void UnaryOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& // clang-format on }; auto it = op_map.find(op_name_); - ICHECK(it != op_map.end()) << "Unsupported unary operation type " << op_name_; + TVM_FFI_ICHECK(it != op_map.end()) << "Unsupported unary operation type " << op_name_; const ANeuralNetworksOperationType operation_type = it->second; const auto input_indices = ExtractOperandIndices(inputs); @@ -112,8 +112,8 @@ void UnaryOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& void SoftmaxOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, const std::vector& inputs, std::vector& outputs) const { - ICHECK_EQ(inputs.size(), 1) << "Unsupported number of inputs for NNAPI softmax operation: " - << inputs.size(); + TVM_FFI_ICHECK_EQ(inputs.size(), 1) + << "Unsupported number of inputs for NNAPI softmax operation: " << inputs.size(); auto input_indices = ExtractOperandIndices(inputs); const auto output_indices = ExtractOperandIndices(outputs); @@ -121,7 +121,7 @@ void SoftmaxOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode // Add the scalar input for beta value at index 1. const auto& input = inputs[0]; // TODO(PLLab): Conditionally use float16 beta for float16 input. - ICHECK_EQ(input.GetTensorType(), ANEURALNETWORKS_TENSOR_FLOAT32) + TVM_FFI_ICHECK_EQ(input.GetTensorType(), ANEURALNETWORKS_TENSOR_FLOAT32) << "NNAPI runtime does not support non-float32 inputs for softmax yet"; const float beta = 1.0f; const NNAPIOperand beta_operand = @@ -180,7 +180,7 @@ NNAPIOperand TransposeOperand(NNAPIModelBuilder& builder, const NNAPIOperand& op void MatmulOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, const std::vector& inputs, std::vector& outputs) const { - ICHECK_EQ(inputs.size(), 2); + TVM_FFI_ICHECK_EQ(inputs.size(), 2); auto input_indices = ExtractOperandIndices(inputs); const auto output_indices = ExtractOperandIndices(outputs); @@ -192,7 +192,7 @@ void MatmulOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& // Check that the extra leading dimensions on input 0 are all ones. const size_t diff = input0_ndim - input1_ndim; for (size_t i = 0; i < diff; ++i) { - ICHECK_EQ(inputs[0].GetDimensions()[i], 1); + TVM_FFI_ICHECK_EQ(inputs[0].GetDimensions()[i], 1); } // Expand input 1's dimensions. @@ -206,7 +206,7 @@ void MatmulOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& // Check that the extra leading dimensions on input 1 are all ones. const size_t diff = input1_ndim - input0_ndim; for (size_t i = 0; i < diff; ++i) { - ICHECK_EQ(inputs[1].GetDimensions()[i], 1); + TVM_FFI_ICHECK_EQ(inputs[1].GetDimensions()[i], 1); } // Expand input 0's dimensions. @@ -238,7 +238,7 @@ void MatmulOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& void TransposeOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& node, const std::vector& inputs, std::vector& outputs) const { - ICHECK_EQ(inputs.size(), 1); + TVM_FFI_ICHECK_EQ(inputs.size(), 1); auto input_indices = ExtractOperandIndices(inputs); auto output_indices = ExtractOperandIndices(outputs); @@ -274,9 +274,9 @@ void CastOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& n // Extract the dtype attribute and check that the output operand type matches the dtype specified. const auto dtype_str = node.GetAttr("astype_dtype"); const DLDataType dtype = StringToDLDataType(std::string(dtype_str)); - ICHECK(outputs.size() == 1); + TVM_FFI_ICHECK(outputs.size() == 1); const auto output_tensor_type = outputs[0].GetTensorType(); - ICHECK(TensorTypeFromDLDataType(dtype) == output_tensor_type) + TVM_FFI_ICHECK(TensorTypeFromDLDataType(dtype) == output_tensor_type) << "Expect a cast to dtype " << dtype_str << " but got output operand of type " << output_tensor_type; @@ -301,14 +301,14 @@ void Conv2dOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& auto input_indices = ExtractOperandIndices(inputs); auto output_indices = ExtractOperandIndices(outputs); - ICHECK(inputs.size() >= 2); + TVM_FFI_ICHECK(inputs.size() >= 2); const auto input_tensor_type = inputs[0].GetTensorType(); const auto filter_tensor_type = inputs[1].GetTensorType(); - ICHECK(input_tensor_type == filter_tensor_type); - ICHECK(input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32 || - input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16); - ICHECK(filter_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32 || - filter_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16); + TVM_FFI_ICHECK(input_tensor_type == filter_tensor_type); + TVM_FFI_ICHECK(input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32 || + input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16); + TVM_FFI_ICHECK(filter_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32 || + filter_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16); // transpose kernel std::vector transposed_dimensions{0, 2, 3, 1}; @@ -347,7 +347,8 @@ void Conv2dOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& padding.push_back(static_cast(padding_attr[i])); } - ICHECK(padding.size() == 4) << "NNAPI runtime currently only supports 4-way padding for Conv2D"; + TVM_FFI_ICHECK(padding.size() == 4) + << "NNAPI runtime currently only supports 4-way padding for Conv2D"; const NNAPIOperand padding_left_operand = builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &padding[1], sizeof(padding[1])); input_indices.push_back(padding_left_operand.GetOperandIndex()); @@ -371,7 +372,7 @@ void Conv2dOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& stride.push_back(static_cast(stride_attr[i])); } - ICHECK(stride.size() == 2); + TVM_FFI_ICHECK(stride.size() == 2); const NNAPIOperand stride_width_operand = builder.CreateScalarOperandWithValue(ANEURALNETWORKS_INT32, &stride[0], sizeof(stride[0])); input_indices.push_back(stride_width_operand.GetOperandIndex()); @@ -491,11 +492,11 @@ void DenseOpConverter::Convert(NNAPIModelBuilder& builder, const JSONGraphNode& auto output_indices = ExtractOperandIndices(outputs); const auto input_tensor_type = inputs[0].GetTensorType(); const auto filter_tensor_type = inputs[1].GetTensorType(); - ICHECK(input_tensor_type == filter_tensor_type); - ICHECK(input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32 || - input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16); - ICHECK(filter_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32 || - filter_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16); + TVM_FFI_ICHECK(input_tensor_type == filter_tensor_type); + TVM_FFI_ICHECK(input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32 || + input_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16); + TVM_FFI_ICHECK(filter_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT32 || + filter_tensor_type == ANEURALNETWORKS_TENSOR_FLOAT16); if (input_indices.size() == 2) { const int output_depth = inputs[1].GetDimensions()[0]; diff --git a/src/runtime/contrib/nnapi/nnapi_runtime.cc b/src/runtime/contrib/nnapi/nnapi_runtime.cc index 992555b7ca00..1939e90992e7 100644 --- a/src/runtime/contrib/nnapi/nnapi_runtime.cc +++ b/src/runtime/contrib/nnapi/nnapi_runtime.cc @@ -71,7 +71,7 @@ class NNAPIRuntime : public JSONRuntimeBase { std::optional compiled_model_; void Init(const ffi::Array& consts) final { - ICHECK_EQ(consts.size(), const_idx_.size()) + TVM_FFI_ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required constants."; SetupConstants(consts); CompileModel(); @@ -114,7 +114,7 @@ class NNAPIRuntime : public JSONRuntimeBase { for (size_t i = 0; i < outputs_.size(); ++i) { const auto& node = outputs_[i]; auto it = node_output_map_.find(node.id_); - ICHECK(it != node_output_map_.end()) << "Missing model output."; + TVM_FFI_ICHECK(it != node_output_map_.end()) << "Missing model output."; const auto& operand = it->second; model_output_operands.push_back(operand); } @@ -131,23 +131,25 @@ class NNAPIRuntime : public JSONRuntimeBase { const std::vector& model_output_operands) { // Execute the model. ANeuralNetworksExecution* execution; - ICHECK_EQ(ANeuralNetworksExecution_create(compilation, &execution), ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksExecution_create(compilation, &execution), + ANEURALNETWORKS_NO_ERROR); for (size_t i = 0; i < input_nodes_.size(); ++i) { const uint32_t nid = input_nodes_[i]; if (nodes_[nid].GetOpType() == "input") { for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) { auto it = node_output_map_.find(nid); - ICHECK(it != node_output_map_.end()) << "Missing model input."; + TVM_FFI_ICHECK(it != node_output_map_.end()) << "Missing model input."; const auto& operand = it->second; const uint32_t eid = EntryID(nid, j); const auto entry = data_entry_[eid]; const auto operand_data_size = GetDataSize(*entry); - ICHECK_EQ(ANeuralNetworksExecution_setInput(execution, i, operand.GetOperandType().Get(), - entry->data, operand_data_size), - ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ( + ANeuralNetworksExecution_setInput(execution, i, operand.GetOperandType().Get(), + entry->data, operand_data_size), + ANEURALNETWORKS_NO_ERROR); } } } @@ -160,22 +162,23 @@ class NNAPIRuntime : public JSONRuntimeBase { const auto entry = data_entry_[eid]; const auto operand_data_size = GetDataSize(*entry); - ICHECK_EQ(ANeuralNetworksExecution_setOutput(execution, i, operand.GetOperandType().Get(), - entry->data, operand_data_size), - ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ( + ANeuralNetworksExecution_setOutput(execution, i, operand.GetOperandType().Get(), + entry->data, operand_data_size), + ANEURALNETWORKS_NO_ERROR); } ANeuralNetworksEvent* compute_event; - ICHECK_EQ(ANeuralNetworksExecution_startCompute(execution, &compute_event), - ANEURALNETWORKS_NO_ERROR); - ICHECK_EQ(ANeuralNetworksEvent_wait(compute_event), ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksExecution_startCompute(execution, &compute_event), + ANEURALNETWORKS_NO_ERROR); + TVM_FFI_ICHECK_EQ(ANeuralNetworksEvent_wait(compute_event), ANEURALNETWORKS_NO_ERROR); ANeuralNetworksEvent_free(compute_event); ANeuralNetworksExecution_free(execution); } void Run() final { - ICHECK(compiled_model_.has_value()); + TVM_FFI_ICHECK(compiled_model_.has_value()); CompiledModel& compiled_model = compiled_model_.value(); ExecuteModel(compiled_model.compilation, compiled_model.model_output_operands); } @@ -188,14 +191,14 @@ class NNAPIRuntime : public JSONRuntimeBase { // Map the op name to its converter. const auto& converter_map = GetOpConverters(); auto it = converter_map.find(node.GetOpName()); - ICHECK(it != converter_map.end()) << node.GetOpName() << ": Unsupported operation name"; + TVM_FFI_ICHECK(it != converter_map.end()) << node.GetOpName() << ": Unsupported operation name"; const NNAPIOpConverter& converter = *it->second; // Add input operands to params. for (size_t i = 0; i < node.GetInputs().size(); ++i) { auto in_node = node.GetInputs()[i]; auto it = node_output_map_.find(in_node.id_); - ICHECK(it != node_output_map_.end()) << node.GetOpName() << ": Missing input"; + TVM_FFI_ICHECK(it != node_output_map_.end()) << node.GetOpName() << ": Missing input"; auto& operand = it->second; inputs.push_back(operand); } @@ -203,9 +206,9 @@ class NNAPIRuntime : public JSONRuntimeBase { // Create and add output operands to params. const auto output_shapes = node.GetOpShape(); const auto output_dtypes = node.GetOpDataType(); - ICHECK(output_shapes.size() == output_dtypes.size()) + TVM_FFI_ICHECK(output_shapes.size() == output_dtypes.size()) << "The number of output shapes must match the number of output dtypes"; - ICHECK(output_shapes.size() == 1) + TVM_FFI_ICHECK(output_shapes.size() == 1) << "NNAPI runtime currently does not support more than one output per operation yet"; for (size_t i = 0; i < output_shapes.size(); ++i) { @@ -227,11 +230,13 @@ class NNAPIRuntime : public JSONRuntimeBase { #else // ifdef TVM_GRAPH_EXECUTOR_NNAPI void Init(const ffi::Array& consts) final { - LOG(FATAL) << "NNAPI runtime is not enabled. Build with USE_NNAPI_RUNTIME to enable it."; + TVM_FFI_THROW(InternalError) + << "NNAPI runtime is not enabled. Build with USE_NNAPI_RUNTIME to enable it."; } void Run() final { - LOG(FATAL) << "NNAPI runtime is not enabled. Build with USE_NNAPI_RUNTIME to enable it."; + TVM_FFI_THROW(InternalError) + << "NNAPI runtime is not enabled. Build with USE_NNAPI_RUNTIME to enable it."; } #endif // ifdef TVM_GRAPH_EXECUTOR_NNAPI }; diff --git a/src/runtime/contrib/nvshmem/init.cc b/src/runtime/contrib/nvshmem/init.cc index 9269c8bdd001..1528f03d8e49 100644 --- a/src/runtime/contrib/nvshmem/init.cc +++ b/src/runtime/contrib/nvshmem/init.cc @@ -48,9 +48,9 @@ void InitNVSHMEM(ffi::Shape uid_64, int num_workers, int worker_id_start) { } else { worker_id = worker_id_start + worker->worker_id; } - CHECK_EQ(uid_64.size(), UNIQUEID_PADDING + 1) - << "ValueError: The length of unique_id must be " << UNIQUEID_PADDING << ", but got " - << uid_64.size() << "."; + TVM_FFI_CHECK_EQ(uid_64.size(), UNIQUEID_PADDING + 1, ValueError) + << "The length of unique_id must be " << UNIQUEID_PADDING << ", but got " << uid_64.size() + << "."; nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER; @@ -69,8 +69,8 @@ void InitNVSHMEM(ffi::Shape uid_64, int num_workers, int worker_id_start) { if (worker->default_device.device_type == DLDeviceType::kDLCPU) { worker->default_device = Device{DLDeviceType::kDLCUDA, mype_node}; } else { - ICHECK(worker->default_device.device_type == DLDeviceType::kDLCUDA && - worker->default_device.device_id == mype_node) + TVM_FFI_ICHECK(worker->default_device.device_type == DLDeviceType::kDLCUDA && + worker->default_device.device_id == mype_node) << "The default device of the worker is inconsistent with the device used for NVSHMEM. " << "The default device is " << worker->default_device << ", but the device used for NVSHMEM is " << Device{DLDeviceType::kDLCUDA, mype_node} @@ -86,10 +86,10 @@ void InitNVSHMEMWrapper(ffi::String args) { ffi::String err; json::Value v = json::Parse(args, &err); if (!err.empty()) { - LOG(FATAL) << "JSON parse error: " << err; + TVM_FFI_THROW(InternalError) << "JSON parse error: " << err; } - CHECK(v.as()) << "JSON is not an object"; + TVM_FFI_ICHECK(v.as()) << "JSON is not an object"; json::Object obj = v.cast(); json::Array uid_array = obj["uid"].cast(); diff --git a/src/runtime/contrib/nvshmem/memory_allocator.cc b/src/runtime/contrib/nvshmem/memory_allocator.cc index c53935f8bc94..21ea448b2233 100644 --- a/src/runtime/contrib/nvshmem/memory_allocator.cc +++ b/src/runtime/contrib/nvshmem/memory_allocator.cc @@ -75,10 +75,11 @@ class NVSHMEMAllocator final : public PooledAllocator { private: void* DeviceAllocDataSpace(Device dev, size_t size, size_t alignment, DLDataType type_hint) final { - ICHECK_EQ(dev.device_type, DLDeviceType::kDLCUDA) + TVM_FFI_ICHECK_EQ(dev.device_type, DLDeviceType::kDLCUDA) << "nvshmem can only allocate CUDA device memory space."; - ICHECK(type_hint.code == DLDataTypeCode::kDLInt || type_hint.code == DLDataTypeCode::kDLUInt || - type_hint.code == DLDataTypeCode::kDLFloat) + TVM_FFI_ICHECK(type_hint.code == DLDataTypeCode::kDLInt || + type_hint.code == DLDataTypeCode::kDLUInt || + type_hint.code == DLDataTypeCode::kDLFloat) << "nvshmem can only allocate tensor with int, usingned int or float data types."; return nvshmem_align(alignment, size); } diff --git a/src/runtime/contrib/papi/papi.cc b/src/runtime/contrib/papi/papi.cc index 917fe1930ed8..5e334d87427f 100644 --- a/src/runtime/contrib/papi/papi.cc +++ b/src/runtime/contrib/papi/papi.cc @@ -27,13 +27,13 @@ namespace tvm { namespace runtime { namespace profiling { -#define PAPI_CALL(func) \ - { \ - int e = (func); \ - if (e != PAPI_OK) { \ - LOG(FATAL) << "PAPIError: in function " #func " " << e << " " \ - << std::string(PAPI_strerror(e)); \ - } \ +#define PAPI_CALL(func) \ + { \ + int e = (func); \ + if (e != PAPI_OK) { \ + TVM_FFI_THROW(PAPIError) << "in function " #func " " << e << " " \ + << std::string(PAPI_strerror(e)); \ + } \ } static const std::unordered_map> default_metric_names = { @@ -77,10 +77,11 @@ int component_for_device(Device dev) { } int cidx = PAPI_get_component_index(component_name.c_str()); if (cidx < 0) { - LOG(FATAL) << "Cannot find PAPI component \"" << component_name - << "\". Maybe you need to build PAPI with support for this component (use " - "`./configure --with-components=" - << component_name << "`)."; + TVM_FFI_THROW(InternalError) + << "Cannot find PAPI component \"" << component_name + << "\". Maybe you need to build PAPI with support for this component (use " + "`./configure --with-components=" + << component_name << "`)."; } return cidx; } @@ -118,7 +119,7 @@ struct PAPIMetricCollectorNode final : public MetricCollectorNode { LOG(WARNING) << "PAPI's long_long is larger than int64_t. Overflow may occur when " "reporting metrics."; } - CHECK_EQ(PAPI_library_init(PAPI_VER_CURRENT), PAPI_VER_CURRENT) + TVM_FFI_ICHECK_EQ(PAPI_library_init(PAPI_VER_CURRENT), PAPI_VER_CURRENT) << "Error while initializing PAPI"; } @@ -196,8 +197,8 @@ struct PAPIMetricCollectorNode final : public MetricCollectorNode { for (auto metric : metric_names) { int e = PAPI_add_named_event(event_set, metric.c_str()); if (e != PAPI_OK) { - LOG(FATAL) << "PAPIError: " << e << " " << std::string(PAPI_strerror(e)) << ": " << metric - << "."; + TVM_FFI_THROW(PAPIError) + << e << " " << std::string(PAPI_strerror(e)) << ": " << metric << "."; } } // Because we may have multiple calls in flight at the same time, we diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc index 0158a66be5dd..496986905cb3 100644 --- a/src/runtime/contrib/random/mt_random_engine.cc +++ b/src/runtime/contrib/random/mt_random_engine.cc @@ -74,8 +74,8 @@ class RandomEngine { * \brief Fills a tensor with values drawn from Unif(low, high) */ void SampleUniform(DLTensor* data, float low, float high) { - ICHECK_GT(high, low) << "high must be bigger than low"; - ICHECK(ffi::IsContiguous(*data)); + TVM_FFI_ICHECK_GT(high, low) << "high must be bigger than low"; + TVM_FFI_ICHECK(ffi::IsContiguous(*data)); DLDataType dtype = data->dtype; int64_t size = 1; @@ -83,14 +83,14 @@ class RandomEngine { size *= data->shape[i]; } - ICHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1); + TVM_FFI_ICHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1); if (data->device.device_type == kDLCPU) { std::uniform_real_distribution uniform_dist(low, high); std::generate_n(static_cast(data->data), size, [&]() { return uniform_dist(rnd_engine_); }); } else { - LOG(FATAL) << "Do not support random.uniform on this device yet"; + TVM_FFI_THROW(InternalError) << "Do not support random.uniform on this device yet"; } } @@ -98,8 +98,8 @@ class RandomEngine { * \brief Fills a tensor with values drawn from Normal(loc, scale**2) */ void SampleNormal(DLTensor* data, float loc, float scale) { - ICHECK_GT(scale, 0) << "standard deviation must be positive"; - ICHECK(ffi::IsContiguous(*data)); + TVM_FFI_ICHECK_GT(scale, 0) << "standard deviation must be positive"; + TVM_FFI_ICHECK(ffi::IsContiguous(*data)); DLDataType dtype = data->dtype; int64_t size = 1; @@ -107,14 +107,14 @@ class RandomEngine { size *= data->shape[i]; } - ICHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1); + TVM_FFI_ICHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1); if (data->device.device_type == kDLCPU) { std::normal_distribution normal_dist(loc, scale); std::generate_n(static_cast(data->data), size, [&]() { return normal_dist(rnd_engine_); }); } else { - LOG(FATAL) << "Do not support random.normal on this device yet"; + TVM_FFI_THROW(InternalError) << "Do not support random.normal on this device yet"; } } @@ -171,7 +171,8 @@ class RandomEngine { std::generate_n(static_cast(data) + st, ed - st, [&]() { return dist(rnd_engine_); }); } else { - LOG(FATAL) << "Doesn't support dtype code " << dtype.code << " dtype bits " << dtype.bits; + TVM_FFI_THROW(InternalError) + << "Doesn't support dtype code " << dtype.code << " dtype bits " << dtype.bits; } } @@ -185,7 +186,8 @@ class RandomEngine { dtype.bits == 32 || dtype.bits == 64) { FillDataImpl(tensor->data, 0, size, dtype); } else { - LOG(FATAL) << "Doesn't support dtype code " << dtype.code << " dtype bits " << dtype.bits; + TVM_FFI_THROW(InternalError) + << "Doesn't support dtype code " << dtype.code << " dtype bits " << dtype.bits; } } @@ -221,9 +223,10 @@ class RandomEngine { if (dtype.bits == 1 || dtype.bits == 4 || dtype.bits == 8 || dtype.bits == 16 || dtype.bits == 32 || dtype.bits == 64) { int res = TVMBackendParallelLaunch(ParallelTask::RunTask, &task, 0); - ICHECK_EQ(res, 0) << "RandomFillForMeasure: TVMBackendParallelLaunch failed"; + TVM_FFI_ICHECK_EQ(res, 0) << "RandomFillForMeasure: TVMBackendParallelLaunch failed"; } else { - LOG(FATAL) << "Doesn't support dtype code " << dtype.code << " dtype bits " << dtype.bits; + TVM_FFI_THROW(InternalError) + << "Doesn't support dtype code " << dtype.code << " dtype bits " << dtype.bits; } } diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc index 1c75a8152a5f..97bd128a6322 100644 --- a/src/runtime/contrib/random/random.cc +++ b/src/runtime/contrib/random/random.cc @@ -30,27 +30,27 @@ #include "mt_random_engine.cc" -#define DLPACK_INTEGER_TYPE_SWITCH(type, DType, ...) \ - if (type.code == kDLInt && type.bits == 32) { \ - typedef int32_t DType; \ - { __VA_ARGS__ } \ - } else if (type.code == kDLInt && type.bits == 16) { \ - typedef int16_t DType; \ - { __VA_ARGS__ } \ - } else if (type.code == kDLInt && type.bits == 8) { \ - typedef int8_t DType; \ - { __VA_ARGS__ } \ - } else if (type.code == kDLUInt && type.bits == 32) { \ - typedef uint32_t DType; \ - { __VA_ARGS__ } \ - } else if (type.code == kDLUInt && type.bits == 16) { \ - typedef uint16_t DType; \ - { __VA_ARGS__ } \ - } else if (type.code == kDLUInt && type.bits == 8) { \ - typedef uint8_t DType; \ - { __VA_ARGS__ } \ - } else { \ - LOG(FATAL) << "unknown data type"; \ +#define DLPACK_INTEGER_TYPE_SWITCH(type, DType, ...) \ + if (type.code == kDLInt && type.bits == 32) { \ + typedef int32_t DType; \ + { __VA_ARGS__ } \ + } else if (type.code == kDLInt && type.bits == 16) { \ + typedef int16_t DType; \ + { __VA_ARGS__ } \ + } else if (type.code == kDLInt && type.bits == 8) { \ + typedef int8_t DType; \ + { __VA_ARGS__ } \ + } else if (type.code == kDLUInt && type.bits == 32) { \ + typedef uint32_t DType; \ + { __VA_ARGS__ } \ + } else if (type.code == kDLUInt && type.bits == 16) { \ + typedef uint16_t DType; \ + { __VA_ARGS__ } \ + } else if (type.code == kDLUInt && type.bits == 8) { \ + typedef uint8_t DType; \ + { __VA_ARGS__ } \ + } else { \ + TVM_FFI_THROW(InternalError) << "unknown data type"; \ } namespace tvm { @@ -77,8 +77,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { int64_t low = args[0].cast(); int64_t high = args[1].cast(); auto out = args[2].cast(); - ICHECK_GT(high, low) << "high must be bigger than low"; - ICHECK(ffi::IsContiguous(*out)); + TVM_FFI_ICHECK_GT(high, low) << "high must be bigger than low"; + TVM_FFI_ICHECK(ffi::IsContiguous(*out)); DLDataType dtype = out->dtype; int64_t size = 1; @@ -100,7 +100,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { return low + rint % (high - low); }); } else { - LOG(FATAL) << "Do not support random.randint on this device yet"; + TVM_FFI_THROW(InternalError) + << "Do not support random.randint on this device yet"; } }) }) diff --git a/src/runtime/contrib/rocblas/rocblas.cc b/src/runtime/contrib/rocblas/rocblas.cc index 62ae5b27a3cd..fa03fd0245e7 100644 --- a/src/runtime/contrib/rocblas/rocblas.cc +++ b/src/runtime/contrib/rocblas/rocblas.cc @@ -80,15 +80,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { bool transa = args[3].cast(); bool transb = args[4].cast(); // call gemm for simple compact code. - ICHECK_EQ(A->ndim, 2); - ICHECK_EQ(B->ndim, 2); - ICHECK_EQ(C->ndim, 2); - ICHECK(ffi::IsContiguous(*C)); - ICHECK(ffi::IsContiguous(*B)); - ICHECK(ffi::IsContiguous(*A)); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); - ICHECK(TypeMatch(B->dtype, kDLFloat, 32)); - ICHECK(TypeMatch(C->dtype, kDLFloat, 32)); + TVM_FFI_ICHECK_EQ(A->ndim, 2); + TVM_FFI_ICHECK_EQ(B->ndim, 2); + TVM_FFI_ICHECK_EQ(C->ndim, 2); + TVM_FFI_ICHECK(ffi::IsContiguous(*C)); + TVM_FFI_ICHECK(ffi::IsContiguous(*B)); + TVM_FFI_ICHECK(ffi::IsContiguous(*A)); + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); + TVM_FFI_ICHECK(TypeMatch(B->dtype, kDLFloat, 32)); + TVM_FFI_ICHECK(TypeMatch(C->dtype, kDLFloat, 32)); float alpha = 1.0; float beta = 0.0; @@ -118,12 +118,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { bool transa = args[3].cast(); bool transb = args[4].cast(); // call gemm for simple compact code. - ICHECK_EQ(A->ndim, 3); - ICHECK_EQ(B->ndim, 3); - ICHECK_EQ(C->ndim, 3); - ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); - ICHECK(TypeMatch(B->dtype, kDLFloat, 32)); - ICHECK(TypeMatch(C->dtype, kDLFloat, 32)); + TVM_FFI_ICHECK_EQ(A->ndim, 3); + TVM_FFI_ICHECK_EQ(B->ndim, 3); + TVM_FFI_ICHECK_EQ(C->ndim, 3); + TVM_FFI_ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); + TVM_FFI_ICHECK(TypeMatch(B->dtype, kDLFloat, 32)); + TVM_FFI_ICHECK(TypeMatch(C->dtype, kDLFloat, 32)); float alpha = 1.0; float beta = 0.0; diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index afbac3a84701..541548d18250 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -102,15 +102,15 @@ void RegisterArgsortNMS() { } // Currently only supports input dtype to be float32. - ICHECK_EQ(dtype.code, 2) << "Currently only supports input dtype " - "to be float."; + TVM_FFI_ICHECK_EQ(dtype.code, 2) << "Currently only supports input dtype " + "to be float."; #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC != 1) - ICHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype " - "to be float32."; + TVM_FFI_ICHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype " + "to be float32."; #endif - ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " - "input ndim " - << input->ndim; + TVM_FFI_ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " + "input ndim " + << input->ndim; for (int i = 0; i < input->ndim; ++i) { if (i < axis) { @@ -232,9 +232,9 @@ void RegisterArgsort() { if (axis < 0) { axis = input->ndim + axis; } - ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " - "input ndim " - << input->ndim; + TVM_FFI_ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " + "input ndim " + << input->ndim; auto data_dtype = ffi::DLDataTypeToString(input->dtype); auto out_dtype = ffi::DLDataTypeToString(output->dtype); @@ -249,7 +249,7 @@ void RegisterArgsort() { } else if (out_dtype == "float64") { argsort(input, output, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "float64") { if (out_dtype == "int32") { @@ -261,14 +261,14 @@ void RegisterArgsort() { } else if (out_dtype == "float64") { argsort(input, output, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; } #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) } else if (data_dtype == "float16") { if (out_dtype == "float16") { argsort<__fp16, __fp16>(input, output, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; } #endif } else if (data_dtype == "int32") { @@ -281,7 +281,7 @@ void RegisterArgsort() { } else if (out_dtype == "float64") { argsort(input, output, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "int64") { if (out_dtype == "int32") { @@ -293,7 +293,7 @@ void RegisterArgsort() { } else if (out_dtype == "float64") { argsort(input, output, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "float16") { if (out_dtype == "int32") { @@ -305,10 +305,10 @@ void RegisterArgsort() { } else if (out_dtype == "float64") { argsort(input, output, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; } } else { - LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported input dtype: " << data_dtype; } }); } @@ -330,14 +330,14 @@ void RegisterSort() { if (axis < 0) { axis = input->ndim + axis; } - ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " - "input ndim " - << input->ndim; + TVM_FFI_ICHECK_LT(axis, input->ndim) << "Axis out of boundary for " + "input ndim " + << input->ndim; auto data_dtype = DLDataTypeToString(input->dtype); auto out_dtype = DLDataTypeToString(output->dtype); - ICHECK_EQ(data_dtype, out_dtype); + TVM_FFI_ICHECK_EQ(data_dtype, out_dtype); if (data_dtype == "float32") { sort(input, output, axis, is_ascend); @@ -354,7 +354,7 @@ void RegisterSort() { } else if (data_dtype == "float16") { sort(input, output, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported input dtype: " << data_dtype; } }); } @@ -470,12 +470,12 @@ void RegisterTopk() { } else if (ret_type == "indices") { indices_out = args[1].cast(); } else { - LOG(FATAL) << "Unsupported ret type: " << ret_type; + TVM_FFI_THROW(InternalError) << "Unsupported ret type: " << ret_type; } if (axis < 0) { axis = input->ndim + axis; } - ICHECK(axis >= 0 && axis < input->ndim) + TVM_FFI_ICHECK(axis >= 0 && axis < input->ndim) << "Axis out of boundary for input ndim " << input->ndim; auto data_dtype = ffi::DLDataTypeToString(input->dtype); @@ -492,7 +492,7 @@ void RegisterTopk() { } else if (out_dtype == "float64") { topk(input, values_out, indices_out, k, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "float64") { if (out_dtype == "int32") { @@ -504,7 +504,7 @@ void RegisterTopk() { } else if (out_dtype == "float64") { topk(input, values_out, indices_out, k, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "uint8") { if (out_dtype == "uint8") { @@ -518,7 +518,7 @@ void RegisterTopk() { } else if (out_dtype == "float64") { topk(input, values_out, indices_out, k, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "int8") { if (out_dtype == "int8") { @@ -532,7 +532,7 @@ void RegisterTopk() { } else if (out_dtype == "float64") { topk(input, values_out, indices_out, k, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "int32") { if (out_dtype == "int32") { @@ -544,7 +544,7 @@ void RegisterTopk() { } else if (out_dtype == "float64") { topk(input, values_out, indices_out, k, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "int64") { if (out_dtype == "int32") { @@ -556,7 +556,7 @@ void RegisterTopk() { } else if (out_dtype == "float64") { topk(input, values_out, indices_out, k, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; } } else if (data_dtype == "float16") { if (out_dtype == "int32") { @@ -568,10 +568,10 @@ void RegisterTopk() { } else if (out_dtype == "float64") { topk(input, values_out, indices_out, k, axis, is_ascend); } else { - LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported output dtype: " << out_dtype; } } else { - LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported input dtype: " << data_dtype; } }); } diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc index 3dd29857c0d7..63d886e520a7 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc @@ -72,7 +72,7 @@ TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger, } nvinfer1::DataType DLDataType2NVDataType(DLDataType data_type) { - ICHECK(data_type.code == kDLFloat && (data_type.bits == 16 || data_type.bits == 32)) + TVM_FFI_ICHECK(data_type.code == kDLFloat && (data_type.bits == 16 || data_type.bits == 32)) << "Invalid input Tensor type. Only float16 and float32 are supported"; return (data_type.bits == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT; } @@ -81,7 +81,7 @@ void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode& auto node_name = node.GetOpName(); auto shapes = node.GetOpShape(); auto dtypes = node.GetOpDataType(); - ICHECK_EQ(shapes.size(), dtypes.size()); + TVM_FFI_ICHECK_EQ(shapes.size(), dtypes.size()); node_output_map_[nid] = {}; for (size_t i = 0; i < shapes.size(); ++i) { const std::string name = node_name + "_" + std::to_string(i); @@ -106,7 +106,7 @@ void TensorRTBuilder::AddConstant(int nid, const DLTensor* data) { void TensorRTBuilder::AddOutput(const JSONGraphNodeEntry& node, uint32_t entry_id) { auto it = node_output_map_.find(node.id_); - ICHECK(it != node_output_map_.end()) << "Output was not found."; + TVM_FFI_ICHECK(it != node_output_map_.end()) << "Output was not found."; auto out_tensor = it->second[node.index_].tensor; std::string name = "tensorrt_output_" + std::to_string(network_output_names_.size()); // If the network is already marked as an input or output, make a copy to avoid TRT crash. @@ -129,23 +129,23 @@ void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) { const std::unordered_map>& map = GetOpConverters(); auto it = map.find(params.op_name); - ICHECK(it != map.end()) << params.op_name << ": Unsupported operator"; + TVM_FFI_ICHECK(it != map.end()) << params.op_name << ": Unsupported operator"; const TensorRTOpConverter& converter = *it->second; if (!converter.variable_input_count) { - ICHECK_EQ(node.GetInputs().size(), converter.input_types.size()) + TVM_FFI_ICHECK_EQ(node.GetInputs().size(), converter.input_types.size()) << params.op_name << ": Mismatched input sizes"; } // Get inputs. for (size_t i = 0; i < node.GetInputs().size(); ++i) { auto in_node = node.GetInputs()[i]; auto it = node_output_map_.find(in_node.id_); - ICHECK(it != node_output_map_.end()) << params.op_name << ": Input was not found"; + TVM_FFI_ICHECK(it != node_output_map_.end()) << params.op_name << ": Input was not found"; auto input = it->second[in_node.index_]; if (!converter.variable_input_count) { if (converter.input_types[i] == kTensor && input.type == kWeight) { input = TensorRTOpInput(GetInputAsTensor(input)); } else if (converter.input_types[i] == kWeight && input.type == kTensor) { - LOG(FATAL) << params.op_name << ": Input " << i << " must be a constant."; + TVM_FFI_THROW(InternalError) << params.op_name << ": Input " << i << " must be a constant."; } } params.inputs.push_back(input); @@ -157,7 +157,8 @@ void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) { // Get outputs. node_output_map_[nid] = {}; auto dtype = node.GetOpDataType(); - ICHECK_EQ(params.outputs.size(), dtype.size()) << params.op_name << ": Mismatched output sizes"; + TVM_FFI_ICHECK_EQ(params.outputs.size(), dtype.size()) + << params.op_name << ": Mismatched output sizes"; for (size_t i = 0; i < params.outputs.size(); ++i) { auto out = params.outputs[i]; out->setType(DLDataType2NVDataType(dtype[i])); @@ -177,7 +178,7 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() { if (use_int8_) { config_->setFlag(nvinfer1::BuilderFlag::kINT8); - ICHECK(calibrator_); + TVM_FFI_ICHECK(calibrator_); config_->setInt8Calibrator(calibrator_); LOG(INFO) << "config finishes setting up calibrator as INT8 mode ... "; } @@ -207,20 +208,21 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() { #else nvinfer1::ICudaEngine* engine = builder_->buildCudaEngine(*network_); #endif - ICHECK_EQ(engine->getNbBindings(), network_input_names_.size() + network_output_names_.size()); + TVM_FFI_ICHECK_EQ(engine->getNbBindings(), + network_input_names_.size() + network_output_names_.size()); nvinfer1::IExecutionContext* context = engine->createExecutionContext(); CleanUp(); - ICHECK(engine); - ICHECK(context); + TVM_FFI_ICHECK(engine); + TVM_FFI_ICHECK(context); return {engine, context, network_input_names_, network_output_names_}; } nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr, DLDeviceType src_device) { - ICHECK_EQ(dptr->device.device_type, src_device); - ICHECK((dptr->dtype.bits != 16 || dptr->dtype.bits != 32)) + TVM_FFI_ICHECK_EQ(dptr->device.device_type, src_device); + TVM_FFI_ICHECK((dptr->dtype.bits != 16 || dptr->dtype.bits != 32)) << "Invalid input Tensor type. Float16 and Float32 are supported"; const auto trt_dtype = (static_cast(dptr->dtype.bits) == 16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT; @@ -233,9 +235,9 @@ nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr, } weight.count = count; weight.values = new float[count]; - ICHECK_EQ(TVMTensorCopyToBytes(const_cast(dptr), const_cast(weight.values), - weight_bytes), - 0) + TVM_FFI_ICHECK_EQ(TVMTensorCopyToBytes(const_cast(dptr), + const_cast(weight.values), weight_bytes), + 0) << TVMGetLastError(); trt_weights_.push_back(weight); return weight; @@ -259,25 +261,25 @@ nvinfer1::ITensor* TensorRTBuilder::GetInputAsTensor(const TensorRTOpInput& inpu void TensorRTBuilder::CleanUp() { VLOG(1) << "Destroying TensorRT network"; - ICHECK(network_); + TVM_FFI_ICHECK(network_); network_->destroy(); network_ = nullptr; #if TRT_VERSION_GE(6, 0, 1) VLOG(1) << "Destroying TensorRT config"; - ICHECK(config_); + TVM_FFI_ICHECK(config_); config_->destroy(); config_ = nullptr; #endif VLOG(1) << "Destroying TensorRT builder"; - ICHECK(builder_); + TVM_FFI_ICHECK(builder_); builder_->destroy(); builder_ = nullptr; VLOG(1) << "Destroying TensorRT weights"; for (auto weight : trt_weights_) { - ICHECK(weight.values); + TVM_FFI_ICHECK(weight.values); if (weight.type == nvinfer1::DataType::kFLOAT || weight.type == nvinfer1::DataType::kHALF) { delete[] static_cast(weight.values); } else { diff --git a/src/runtime/contrib/tensorrt/tensorrt_calibrator.h b/src/runtime/contrib/tensorrt/tensorrt_calibrator.h index 491714acd927..fea1a4684df4 100755 --- a/src/runtime/contrib/tensorrt/tensorrt_calibrator.h +++ b/src/runtime/contrib/tensorrt/tensorrt_calibrator.h @@ -70,9 +70,9 @@ class TensorRTCalibrator : public nvinfer1::IInt8EntropyCalibrator2 { */ bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override { AllocateBuffersIfNotAllocated(); - CHECK_EQ(input_names_.size(), nbBindings); + TVM_FFI_ICHECK_EQ(input_names_.size(), nbBindings); for (size_t i = 0; i < input_names_.size(); ++i) { - CHECK_EQ(input_names_[i], names[i]); + TVM_FFI_ICHECK_EQ(input_names_[i], names[i]); CUDA_CALL(cudaMemcpy(buffers_[i], data_[num_batches_calibrated_][i], batch_size_ * data_sizes_[num_batches_calibrated_][i] * sizeof(float), cudaMemcpyHostToDevice)); @@ -116,7 +116,7 @@ class TensorRTCalibrator : public nvinfer1::IInt8EntropyCalibrator2 { * entry. */ void AllocateBuffersIfNotAllocated() { if (!buffers_.empty()) return; - CHECK_GE(data_sizes_.size(), 1); + TVM_FFI_ICHECK_GE(data_sizes_.size(), 1); const int num_inputs = data_sizes_[0].size(); buffers_.assign(num_inputs, nullptr); for (int i = 0; i < num_inputs; ++i) { diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc index 458789d41169..9ba6960acf64 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -51,7 +51,7 @@ nvinfer1::ITensor* TensorRTOpConverter::Reshape(TensorRTOpConverterParams* param nvinfer1::ITensor* input, const std::vector& new_shape) const { auto layer = params->network->addShuffle(*input); - ICHECK(layer != nullptr); + TVM_FFI_ICHECK(layer != nullptr); layer->setReshapeDimensions(VectorToTrtDims(new_shape)); layer->setOutputType(0, input->getType()); return layer->getOutput(0); @@ -61,17 +61,17 @@ nvinfer1::ITensor* TensorRTOpConverter::Transpose(TensorRTOpConverterParams* par nvinfer1::ITensor* input, const std::vector& order) const { auto layer = params->network->addShuffle(*input); - ICHECK(layer != nullptr); + TVM_FFI_ICHECK(layer != nullptr); nvinfer1::Permutation perm; if (TRT_HAS_IMPLICIT_BATCH(params)) { // Batch dimension cannot be modified. - ICHECK_EQ(input->getDimensions().nbDims, order.size() - 1); - ICHECK_EQ(order[0], 0); + TVM_FFI_ICHECK_EQ(input->getDimensions().nbDims, order.size() - 1); + TVM_FFI_ICHECK_EQ(order[0], 0); for (size_t i = 0; i + 1 < order.size(); ++i) { perm.order[i] = order[i + 1] - 1; } } else { - ICHECK_EQ(input->getDimensions().nbDims, order.size()); + TVM_FFI_ICHECK_EQ(input->getDimensions().nbDims, order.size()); for (size_t i = 0; i < order.size(); ++i) { perm.order[i] = order[i]; } @@ -86,11 +86,11 @@ int TensorRTOpConverter::ConvertAxis(TensorRTOpConverterParams* params, int axis if (TRT_HAS_IMPLICIT_BATCH(params)) { input_rank += 1; } - ICHECK(axis >= -input_rank && axis < input_rank); + TVM_FFI_ICHECK(axis >= -input_rank && axis < input_rank); if (axis < 0) axis += input_rank; if (TRT_HAS_IMPLICIT_BATCH(params)) { // Can't modify batch dimenson. - ICHECK_NE(axis, 0); + TVM_FFI_ICHECK_NE(axis, 0); // Subtract 1 for implicit batch dim. axis -= 1; } @@ -113,7 +113,7 @@ nvinfer1::ITensor* TensorRTOpConverter::CreateScalar( void TensorRTOpConverter::GetPadding(const ffi::Array& padding, bool* use_asymmetric_padding, nvinfer1::DimsHW* prepadding, nvinfer1::DimsHW* postpadding) const { - ICHECK(padding.size() == 1 || padding.size() == 2 || padding.size() == 4); + TVM_FFI_ICHECK(padding.size() == 1 || padding.size() == 2 || padding.size() == 4); if (padding.size() == 4) { // four int : padding width in the order of (top, left, bottom, right). *prepadding = nvinfer1::DimsHW(static_cast(padding[0]), static_cast(padding[1])); @@ -135,7 +135,7 @@ void TensorRTOpConverter::GetPadding(const ffi::Array& padding, void TensorRTOpConverter::GetPadding3D(const ffi::Array& padding, bool* use_asymmetric_padding, nvinfer1::Dims* prepadding, nvinfer1::Dims* postpadding) const { - ICHECK(padding.size() == 1 || padding.size() == 3 || padding.size() == 6); + TVM_FFI_ICHECK(padding.size() == 1 || padding.size() == 3 || padding.size() == 6); if (padding.size() == 6) { // six int : padding width in the order of (front, top, left, back, bottom, right) *prepadding = nvinfer1::Dims3(static_cast(padding[0]), static_cast(padding[1]), @@ -175,7 +175,7 @@ class ActivationOpConverter : public TensorRTOpConverter { #endif }; auto it = op_map.find(op_name); - ICHECK(it != op_map.end()) << "Unsupported activation type " << op_name; + TVM_FFI_ICHECK(it != op_map.end()) << "Unsupported activation type " << op_name; nvinfer1::IActivationLayer* act_layer = params->network->addActivation(*params->inputs.at(0).tensor, it->second); #if TRT_VERSION_GE(5, 1, 5) @@ -189,7 +189,7 @@ class ActivationOpConverter : public TensorRTOpConverter { act_layer->setAlpha(alpha); } #endif - ICHECK(act_layer != nullptr); + TVM_FFI_ICHECK(act_layer != nullptr); params->outputs.push_back(act_layer->getOutput(0)); } }; @@ -210,7 +210,7 @@ class ElementWiseBinaryOpConverter : public TensorRTOpConverter { {"maximum", nvinfer1::ElementWiseOperation::kMAX}, {"minimum", nvinfer1::ElementWiseOperation::kMIN}}; auto it = op_map.find(op_name); - ICHECK(it != op_map.end()) << "Unsupported elementwise type " << op_name; + TVM_FFI_ICHECK(it != op_map.end()) << "Unsupported elementwise type " << op_name; // Broadcast auto input0 = params->inputs.at(0).tensor; auto input0_dims = TrtDimsToVector(input0->getDimensions()); @@ -231,7 +231,7 @@ class ElementWiseBinaryOpConverter : public TensorRTOpConverter { nvinfer1::IElementWiseLayer* elemwise_layer = params->network->addElementWise(*input0, *input1, it->second); - ICHECK(elemwise_layer != nullptr); + TVM_FFI_ICHECK(elemwise_layer != nullptr); params->outputs.push_back(elemwise_layer->getOutput(0)); } }; @@ -246,8 +246,8 @@ class Conv1DOpConverter : public TensorRTOpConverter { auto input_tensor = params->inputs.at(0).tensor; auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); auto weight_shape = params->inputs.at(1).weight_shape; - ICHECK_EQ(params->node.GetAttr("data_layout"), "NCW"); - ICHECK_EQ(params->node.GetAttr("kernel_layout"), "OIW"); + TVM_FFI_ICHECK_EQ(params->node.GetAttr("data_layout"), "NCW"); + TVM_FFI_ICHECK_EQ(params->node.GetAttr("kernel_layout"), "OIW"); auto strides = params->node.GetAttr>("strides"); auto dilation = params->node.GetAttr>("dilation"); auto padding = params->node.GetAttr>("padding"); @@ -267,12 +267,12 @@ class Conv1DOpConverter : public TensorRTOpConverter { auto conv_layer = params->network->addConvolution(*input_tensor, channels, kernel_size, params->inputs.at(1).weight, bias); - ICHECK(conv_layer != nullptr); + TVM_FFI_ICHECK(conv_layer != nullptr); conv_layer->setPadding(nvinfer1::DimsHW(static_cast(padding[0]), 0)); - ICHECK_EQ(strides.size(), 1); + TVM_FFI_ICHECK_EQ(strides.size(), 1); const auto trt_strides = nvinfer1::DimsHW(static_cast(strides[0]), 1); conv_layer->setStride(trt_strides); - ICHECK_EQ(dilation.size(), 1); + TVM_FFI_ICHECK_EQ(dilation.size(), 1); const auto trt_dilation = nvinfer1::DimsHW(static_cast(dilation[0]), 1); conv_layer->setDilation(trt_dilation); conv_layer->setNbGroups(groups); @@ -296,10 +296,10 @@ class Conv2DOpConverter : public TensorRTOpConverter { auto input_tensor = params->inputs.at(0).tensor; auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); auto weight_shape = params->inputs.at(1).weight_shape; - ICHECK_EQ(params->node.GetAttr("data_layout"), "NCHW"); - ICHECK(params->node.GetAttr("out_layout") == "" || - params->node.GetAttr("out_layout") == "NCHW"); - ICHECK_EQ(params->node.GetAttr("kernel_layout"), "OIHW"); + TVM_FFI_ICHECK_EQ(params->node.GetAttr("data_layout"), "NCHW"); + TVM_FFI_ICHECK(params->node.GetAttr("out_layout") == "" || + params->node.GetAttr("out_layout") == "NCHW"); + TVM_FFI_ICHECK_EQ(params->node.GetAttr("kernel_layout"), "OIHW"); auto strides = params->node.GetAttr>("strides"); auto dilation = params->node.GetAttr>("dilation"); auto padding = params->node.GetAttr>("padding"); @@ -314,7 +314,7 @@ class Conv2DOpConverter : public TensorRTOpConverter { #if !TRT_VERSION_GE(5, 1, 5) if (use_asymmetric_padding) { auto pad_layer = params->network->addPadding(*input_tensor, prepadding, postpadding); - ICHECK(pad_layer != nullptr); + TVM_FFI_ICHECK(pad_layer != nullptr); input_tensor = pad_layer->getOutput(0); // No need for conv op to do any padding. use_asymmetric_padding = false; @@ -327,7 +327,7 @@ class Conv2DOpConverter : public TensorRTOpConverter { nvinfer1::Weights bias{weight_type, nullptr, 0}; auto conv_layer = params->network->addConvolution(*input_tensor, channels, kernel_size, params->inputs.at(1).weight, bias); - ICHECK(conv_layer != nullptr); + TVM_FFI_ICHECK(conv_layer != nullptr); conv_layer->setName(params->LayerName().c_str()); if (use_asymmetric_padding) { #if TRT_VERSION_GE(5, 1, 5) @@ -337,11 +337,11 @@ class Conv2DOpConverter : public TensorRTOpConverter { } else { conv_layer->setPadding(prepadding); } - ICHECK_EQ(strides.size(), 2); + TVM_FFI_ICHECK_EQ(strides.size(), 2); const auto trt_strides = nvinfer1::DimsHW(static_cast(strides[0]), static_cast(strides[1])); conv_layer->setStride(trt_strides); - ICHECK_EQ(dilation.size(), 2); + TVM_FFI_ICHECK_EQ(dilation.size(), 2); const auto trt_dilation = nvinfer1::DimsHW(static_cast(dilation[0]), static_cast(dilation[1])); conv_layer->setDilation(trt_dilation); @@ -361,10 +361,10 @@ class Conv3DOpConverter : public TensorRTOpConverter { auto input_tensor = params->inputs.at(0).tensor; auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); auto weight_shape = params->inputs.at(1).weight_shape; - ICHECK_EQ(params->node.GetAttr("data_layout"), "NCDHW"); - ICHECK(params->node.GetAttr("out_layout") == "" || - params->node.GetAttr("out_layout") == "NCDHW"); - ICHECK_EQ(params->node.GetAttr("kernel_layout"), "OIDHW"); + TVM_FFI_ICHECK_EQ(params->node.GetAttr("data_layout"), "NCDHW"); + TVM_FFI_ICHECK(params->node.GetAttr("out_layout") == "" || + params->node.GetAttr("out_layout") == "NCDHW"); + TVM_FFI_ICHECK_EQ(params->node.GetAttr("kernel_layout"), "OIDHW"); auto strides = params->node.GetAttr>("strides"); auto dilation = params->node.GetAttr>("dilation"); auto padding = params->node.GetAttr>("padding"); @@ -380,18 +380,18 @@ class Conv3DOpConverter : public TensorRTOpConverter { nvinfer1::Weights bias{weight_type, nullptr, 0}; auto conv_layer = params->network->addConvolutionNd(*input_tensor, num_outputs, kernel_size, params->inputs.at(1).weight, bias); - ICHECK(conv_layer != nullptr); + TVM_FFI_ICHECK(conv_layer != nullptr); if (use_asymmetric_padding) { conv_layer->setPrePadding(prepadding); conv_layer->setPostPadding(postpadding); } else { conv_layer->setPaddingNd(prepadding); } - ICHECK_EQ(strides.size(), 3); + TVM_FFI_ICHECK_EQ(strides.size(), 3); const auto trt_strides = nvinfer1::Dims3( static_cast(strides[0]), static_cast(strides[1]), static_cast(strides[2])); conv_layer->setStrideNd(trt_strides); - ICHECK_EQ(dilation.size(), 3); + TVM_FFI_ICHECK_EQ(dilation.size(), 3); const auto trt_dilation = nvinfer1::Dims3(static_cast(dilation[0]), static_cast(dilation[1]), static_cast(dilation[2])); @@ -411,7 +411,7 @@ class DenseOpConverter : public TensorRTOpConverter { void Convert(TensorRTOpConverterParams* params) const { auto input_tensor = params->inputs.at(0).tensor; auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); - ICHECK(input_dims.size() > 0 && input_dims.size() <= 3); + TVM_FFI_ICHECK(input_dims.size() > 0 && input_dims.size() <= 3); const size_t required_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 3 : 4; const bool need_reshape_on_input = input_dims.size() != required_rank; if (need_reshape_on_input) { @@ -421,13 +421,13 @@ class DenseOpConverter : public TensorRTOpConverter { input_tensor = Reshape(params, input_tensor, new_shape); } // Weights are in KC format. - ICHECK_EQ(params->inputs.at(1).weight_shape.size(), 2); + TVM_FFI_ICHECK_EQ(params->inputs.at(1).weight_shape.size(), 2); const int num_units = params->inputs.at(1).weight_shape[0]; const nvinfer1::DataType weight_type = params->inputs.at(1).weight.type; nvinfer1::Weights bias{weight_type, nullptr, 0}; nvinfer1::IFullyConnectedLayer* fc_layer = params->network->addFullyConnected( *input_tensor, num_units, params->inputs.at(1).weight, bias); - ICHECK(fc_layer != nullptr); + TVM_FFI_ICHECK(fc_layer != nullptr); auto output_tensor = fc_layer->getOutput(0); if (need_reshape_on_input) { // Remove added dims. @@ -450,9 +450,9 @@ class BatchNormOpConverter : public TensorRTOpConverter { auto beta = params->inputs.at(2).weight; auto mean = params->inputs.at(3).weight; auto var = params->inputs.at(4).weight; - ICHECK_EQ(gamma.count, beta.count); - ICHECK_EQ(gamma.count, mean.count); - ICHECK_EQ(gamma.count, var.count); + TVM_FFI_ICHECK_EQ(gamma.count, beta.count); + TVM_FFI_ICHECK_EQ(gamma.count, mean.count); + TVM_FFI_ICHECK_EQ(gamma.count, var.count); const float epsilon = static_cast(params->node.GetAttr("epsilon")); const int axis = static_cast(params->node.GetAttr("axis")); const bool scale = static_cast(params->node.GetAttr("scale")); @@ -460,7 +460,7 @@ class BatchNormOpConverter : public TensorRTOpConverter { auto input_dims = TrtDimsToVector(input->getDimensions()); const size_t min_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 3 : 4; const size_t max_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 4 : 5; - ICHECK_LE(input_dims.size(), max_rank); + TVM_FFI_ICHECK_LE(input_dims.size(), max_rank); const bool need_reshape = input_dims.size() < min_rank; const bool need_transpose = axis != 1; @@ -475,7 +475,7 @@ class BatchNormOpConverter : public TensorRTOpConverter { // Transpose if needed. const int input_rank_with_batch = input->getDimensions().nbDims + (TRT_HAS_IMPLICIT_BATCH(params) ? 1 : 0); - ICHECK(input_rank_with_batch == 4 || input_rank_with_batch == 5); + TVM_FFI_ICHECK(input_rank_with_batch == 4 || input_rank_with_batch == 5); std::vector transpose_order(input_rank_with_batch); if (need_transpose) { // Move axis dim to first dim after batch. @@ -521,11 +521,11 @@ class BatchNormOpConverter : public TensorRTOpConverter { nvinfer1::IScaleLayer* scale_layer = params->network->addScaleNd( *input, nvinfer1::ScaleMode::kCHANNEL, weight_shift, weight_scale, power, channel_dim); #else - ICHECK_EQ(input->getDimensions().nbDims, 3); + TVM_FFI_ICHECK_EQ(input->getDimensions().nbDims, 3); nvinfer1::IScaleLayer* scale_layer = params->network->addScale( *input, nvinfer1::ScaleMode::kCHANNEL, weight_shift, weight_scale, power); #endif - ICHECK(scale_layer != nullptr); + TVM_FFI_ICHECK(scale_layer != nullptr); auto output = scale_layer->getOutput(0); if (need_transpose) { output = Transpose(params, output, transpose_order); @@ -547,7 +547,7 @@ class LayerNormOpConverter : public TensorRTOpConverter { auto input = params->inputs.at(0).tensor; auto gamma_input = params->inputs.at(1).weight; auto beta_input = params->inputs.at(2).weight; - ICHECK_EQ(gamma_input.count, beta_input.count); + TVM_FFI_ICHECK_EQ(gamma_input.count, beta_input.count); const float epsilon = static_cast(params->node.GetAttr("epsilon")); const bool scale = static_cast(params->node.GetAttr("scale")); @@ -566,45 +566,45 @@ class LayerNormOpConverter : public TensorRTOpConverter { // Compute mean auto mean_layer = params->network->addReduce(*input, nvinfer1::ReduceOperation::kAVG, 1 << axis, /*keepdims=*/true); - ICHECK(mean_layer != nullptr); + TVM_FFI_ICHECK(mean_layer != nullptr); auto mean = mean_layer->getOutput(0); // Compute variance auto diff_layer = params->network->addElementWise(*input, *mean, nvinfer1::ElementWiseOperation::kSUB); - ICHECK(diff_layer != nullptr); + TVM_FFI_ICHECK(diff_layer != nullptr); auto square_layer = params->network->addElementWise(*diff_layer->getOutput(0), *diff_layer->getOutput(0), nvinfer1::ElementWiseOperation::kPROD); - ICHECK(square_layer != nullptr); + TVM_FFI_ICHECK(square_layer != nullptr); auto var_layer = params->network->addReduce( *square_layer->getOutput(0), nvinfer1::ReduceOperation::kAVG, 1 << axis, /*keepdims=*/true); - ICHECK(var_layer != nullptr); + TVM_FFI_ICHECK(var_layer != nullptr); auto var = var_layer->getOutput(0); // sqrt(var + epsilon) auto epsilon_tensor = CreateScalar(params, epsilon, var->getDimensions()); auto denom_add_layer = params->network->addElementWise(*var, *epsilon_tensor, nvinfer1::ElementWiseOperation::kSUM); - ICHECK(denom_add_layer != nullptr); + TVM_FFI_ICHECK(denom_add_layer != nullptr); auto denom_layer = params->network->addUnary(*denom_add_layer->getOutput(0), nvinfer1::UnaryOperation::kSQRT); - ICHECK(denom_layer != nullptr); + TVM_FFI_ICHECK(denom_layer != nullptr); // (input - mean) / sqrt(var + epsilon) auto output_layer = params->network->addElementWise(*diff_layer->getOutput(0), *denom_layer->getOutput(0), nvinfer1::ElementWiseOperation::kDIV); - ICHECK(output_layer != nullptr); + TVM_FFI_ICHECK(output_layer != nullptr); auto output = output_layer->getOutput(0); if (scale) { auto scale_layer = params->network->addElementWise(*output, *gamma, nvinfer1::ElementWiseOperation::kPROD); - ICHECK(scale_layer != nullptr); + TVM_FFI_ICHECK(scale_layer != nullptr); output = scale_layer->getOutput(0); } if (center) { auto center_layer = params->network->addElementWise(*output, *beta, nvinfer1::ElementWiseOperation::kSUM); - ICHECK(center_layer != nullptr); + TVM_FFI_ICHECK(center_layer != nullptr); output = center_layer->getOutput(0); } params->outputs.push_back(output); @@ -639,7 +639,7 @@ class SoftmaxOpConverter : public TensorRTOpConverter { const int axis = ConvertAxis(params, original_axis, input_rank); nvinfer1::ISoftMaxLayer* softmax_layer = params->network->addSoftMax(*input); softmax_layer->setAxes(1 << axis); - ICHECK(softmax_layer != nullptr); + TVM_FFI_ICHECK(softmax_layer != nullptr); params->outputs.push_back(softmax_layer->getOutput(0)); } }; @@ -656,8 +656,8 @@ class PoolingOpConverter : public TensorRTOpConverter { {"nn.max_pool2d", nvinfer1::PoolingType::kMAX}, {"nn.avg_pool2d", nvinfer1::PoolingType::kAVERAGE}}; auto it = op_map.find(op_name); - ICHECK(it != op_map.end()) << "Unsupported pooling type " << op_name << " in TensorRT"; - ICHECK_EQ(params->node.GetAttr("layout"), "NCHW"); + TVM_FFI_ICHECK(it != op_map.end()) << "Unsupported pooling type " << op_name << " in TensorRT"; + TVM_FFI_ICHECK_EQ(params->node.GetAttr("layout"), "NCHW"); auto pool_size = params->node.GetAttr>("pool_size"); auto padding = params->node.GetAttr>("padding"); auto strides = params->node.GetAttr>("strides"); @@ -671,7 +671,7 @@ class PoolingOpConverter : public TensorRTOpConverter { #if !TRT_VERSION_GE(5, 1, 5) if (use_asymmetric_padding) { auto pad_layer = params->network->addPadding(*input, prepadding, postpadding); - ICHECK(pad_layer != nullptr); + TVM_FFI_ICHECK(pad_layer != nullptr); input = pad_layer->getOutput(0); // No need for pooling op to do any padding. use_asymmetric_padding = false; @@ -682,7 +682,7 @@ class PoolingOpConverter : public TensorRTOpConverter { nvinfer1::DimsHW window_size = nvinfer1::DimsHW(static_cast(pool_size[0]), static_cast(pool_size[1])); auto pool_layer = params->network->addPooling(*input, it->second, window_size); - ICHECK(pool_layer != nullptr); + TVM_FFI_ICHECK(pool_layer != nullptr); nvinfer1::DimsHW trt_strides = nvinfer1::DimsHW(static_cast(strides[0]), static_cast(strides[1])); pool_layer->setStride(trt_strides); @@ -711,7 +711,7 @@ class PoolingOpConverter : public TensorRTOpConverter { pool_layer->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_UP); } #else - ICHECK(!ceil_mode); + TVM_FFI_ICHECK(!ceil_mode); #endif params->outputs.push_back(pool_layer->getOutput(0)); } @@ -730,8 +730,8 @@ class Pooling3DOpConverter : public TensorRTOpConverter { {"nn.max_pool3d", nvinfer1::PoolingType::kMAX}, {"nn.avg_pool3d", nvinfer1::PoolingType::kAVERAGE}}; auto it = op_map.find(op_name); - ICHECK(it != op_map.end()) << "Unsupported pooling type " << op_name << " in TensorRT"; - ICHECK_EQ(params->node.GetAttr("layout"), "NCDHW"); + TVM_FFI_ICHECK(it != op_map.end()) << "Unsupported pooling type " << op_name << " in TensorRT"; + TVM_FFI_ICHECK_EQ(params->node.GetAttr("layout"), "NCDHW"); auto pool_size = params->node.GetAttr>("pool_size"); auto padding = params->node.GetAttr>("padding"); auto strides = params->node.GetAttr>("strides"); @@ -743,7 +743,7 @@ class Pooling3DOpConverter : public TensorRTOpConverter { nvinfer1::Dims3(static_cast(pool_size[0]), static_cast(pool_size[1]), static_cast(pool_size[2])); auto pool_layer = params->network->addPoolingNd(*input, it->second, window_size); - ICHECK(pool_layer != nullptr); + TVM_FFI_ICHECK(pool_layer != nullptr); nvinfer1::Dims trt_strides = nvinfer1::Dims3( static_cast(strides[0]), static_cast(strides[1]), static_cast(strides[2])); pool_layer->setStrideNd(trt_strides); @@ -778,13 +778,13 @@ class GlobalPoolingOpConverter : public TensorRTOpConverter { {"nn.global_max_pool2d", nvinfer1::PoolingType::kMAX}, {"nn.global_avg_pool2d", nvinfer1::PoolingType::kAVERAGE}}; auto it = op_map.find(op_name); - ICHECK(it != op_map.end()) << "Unsupported pooling type " << op_name << " in TensorRT"; - ICHECK_EQ(params->node.GetAttr("layout"), "NCHW"); + TVM_FFI_ICHECK(it != op_map.end()) << "Unsupported pooling type " << op_name << " in TensorRT"; + TVM_FFI_ICHECK_EQ(params->node.GetAttr("layout"), "NCHW"); const int h = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[1] : input_dims[2]; const int w = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[2] : input_dims[3]; auto pool_layer = params->network->addPooling(*input_tensor, it->second, nvinfer1::DimsHW(h, w)); - ICHECK(pool_layer != nullptr); + TVM_FFI_ICHECK(pool_layer != nullptr); params->outputs.push_back(pool_layer->getOutput(0)); } }; @@ -854,10 +854,10 @@ class UnaryOpConverter : public TensorRTOpConverter { #endif }; auto it = op_map.find(op_name); - ICHECK(it != op_map.end()) << "Unsupported unary type " << op_name; + TVM_FFI_ICHECK(it != op_map.end()) << "Unsupported unary type " << op_name; nvinfer1::IUnaryLayer* unary_layer = params->network->addUnary(*params->inputs.at(0).tensor, it->second); - ICHECK(unary_layer != nullptr); + TVM_FFI_ICHECK(unary_layer != nullptr); params->outputs.push_back(unary_layer->getOutput(0)); } }; @@ -870,12 +870,12 @@ class ConcatOpConverter : public TensorRTOpConverter { void Convert(TensorRTOpConverterParams* params) const { const int num_inputs = params->inputs.size(); - ICHECK_GT(num_inputs, 0); + TVM_FFI_ICHECK_GT(num_inputs, 0); const int input_rank = params->inputs[0].tensor->getDimensions().nbDims; std::vector input_tensors; for (auto input : params->inputs) { - ICHECK_EQ(input.type, kTensor); - ICHECK_EQ(input_rank, input.tensor->getDimensions().nbDims); + TVM_FFI_ICHECK_EQ(input.type, kTensor); + TVM_FFI_ICHECK_EQ(input_rank, input.tensor->getDimensions().nbDims); input_tensors.push_back(input.tensor); } @@ -884,7 +884,7 @@ class ConcatOpConverter : public TensorRTOpConverter { nvinfer1::IConcatenationLayer* concat_layer = params->network->addConcatenation(input_tensors.data(), input_tensors.size()); - ICHECK(concat_layer != nullptr); + TVM_FFI_ICHECK(concat_layer != nullptr); concat_layer->setAxis(axis); params->outputs.push_back(concat_layer->getOutput(0)); } @@ -959,7 +959,7 @@ class BiasAddOpConverter : public TensorRTOpConverter { } else if (TRT_HAS_IMPLICIT_BATCH(params)) { axis -= 1; } - ICHECK(input_dims.size() > 0 && input_dims.size() <= required_rank); + TVM_FFI_ICHECK(input_dims.size() > 0 && input_dims.size() <= required_rank); const bool need_reshape_on_input = input_dims.size() != required_rank; if (need_reshape_on_input) { // Add dims of size 1 until rank is required_rank. @@ -975,7 +975,7 @@ class BiasAddOpConverter : public TensorRTOpConverter { nvinfer1::IScaleLayer* scale_layer = params->network->addScaleNd(*input_tensor, nvinfer1::ScaleMode::kCHANNEL, params->inputs.at(1).weight, scale, power, axis); - ICHECK(scale_layer != nullptr); + TVM_FFI_ICHECK(scale_layer != nullptr); auto output_tensor = scale_layer->getOutput(0); if (need_reshape_on_input) { // Remove added dims. @@ -994,12 +994,12 @@ class Conv2DTransposeOpConverter : public TensorRTOpConverter { void Convert(TensorRTOpConverterParams* params) const { auto input_tensor = params->inputs.at(0).tensor; auto weight_shape = params->inputs.at(1).weight_shape; - ICHECK_EQ(params->node.GetAttr("data_layout"), "NCHW"); - ICHECK(params->node.GetAttr("out_layout") == "" || - params->node.GetAttr("out_layout") == "NCHW"); - ICHECK_EQ(params->node.GetAttr("kernel_layout"), "OIHW"); + TVM_FFI_ICHECK_EQ(params->node.GetAttr("data_layout"), "NCHW"); + TVM_FFI_ICHECK(params->node.GetAttr("out_layout") == "" || + params->node.GetAttr("out_layout") == "NCHW"); + TVM_FFI_ICHECK_EQ(params->node.GetAttr("kernel_layout"), "OIHW"); auto dilation = params->node.GetAttr>("dilation"); - ICHECK(static_cast(dilation[0]) == 1 && static_cast(dilation[1]) == 1); + TVM_FFI_ICHECK(static_cast(dilation[0]) == 1 && static_cast(dilation[1]) == 1); auto strides = params->node.GetAttr>("strides"); auto padding = params->node.GetAttr>("padding"); auto output_padding = params->node.GetAttr>("output_padding"); @@ -1013,7 +1013,7 @@ class Conv2DTransposeOpConverter : public TensorRTOpConverter { #if !TRT_VERSION_GE(5, 1, 5) if (use_asymmetric_padding) { auto pad_layer = params->network->addPadding(*input_tensor, prepadding, postpadding); - ICHECK(pad_layer != nullptr); + TVM_FFI_ICHECK(pad_layer != nullptr); input_tensor = pad_layer->getOutput(0); // No need for conv op to do any padding. use_asymmetric_padding = false; @@ -1027,7 +1027,7 @@ class Conv2DTransposeOpConverter : public TensorRTOpConverter { nvinfer1::Weights bias{weight_type, nullptr, 0}; auto deconv_layer = params->network->addDeconvolution(*input_tensor, num_outputs, kernel_size, params->inputs.at(1).weight, bias); - ICHECK(deconv_layer != nullptr); + TVM_FFI_ICHECK(deconv_layer != nullptr); if (use_asymmetric_padding) { #if TRT_VERSION_GE(5, 1, 5) deconv_layer->setPrePadding(prepadding); @@ -1066,14 +1066,14 @@ class Conv3DTransposeOpConverter : public TensorRTOpConverter { void Convert(TensorRTOpConverterParams* params) const { auto input_tensor = params->inputs.at(0).tensor; auto weight_shape = params->inputs.at(1).weight_shape; - ICHECK_EQ(params->node.GetAttr("data_layout"), "NCDHW"); - ICHECK(params->node.GetAttr("out_layout") == "" || - params->node.GetAttr("out_layout") == "NCDHW"); - ICHECK_EQ(params->node.GetAttr("kernel_layout"), "OIDHW"); + TVM_FFI_ICHECK_EQ(params->node.GetAttr("data_layout"), "NCDHW"); + TVM_FFI_ICHECK(params->node.GetAttr("out_layout") == "" || + params->node.GetAttr("out_layout") == "NCDHW"); + TVM_FFI_ICHECK_EQ(params->node.GetAttr("kernel_layout"), "OIDHW"); auto dilation = params->node.GetAttr>("dilation"); - ICHECK_EQ(dilation.size(), 3); - ICHECK(static_cast(dilation[0]) == 1 && static_cast(dilation[1]) == 1 && - static_cast(dilation[2]) == 1); + TVM_FFI_ICHECK_EQ(dilation.size(), 3); + TVM_FFI_ICHECK(static_cast(dilation[0]) == 1 && static_cast(dilation[1]) == 1 && + static_cast(dilation[2]) == 1); auto strides = params->node.GetAttr>("strides"); auto padding = params->node.GetAttr>("padding"); auto output_padding = params->node.GetAttr>("output_padding"); @@ -1088,14 +1088,14 @@ class Conv3DTransposeOpConverter : public TensorRTOpConverter { nvinfer1::Weights bias{weight_type, nullptr, 0}; auto deconv_layer = params->network->addDeconvolutionNd(*input_tensor, num_outputs, kernel_size, params->inputs.at(1).weight, bias); - ICHECK(deconv_layer != nullptr); + TVM_FFI_ICHECK(deconv_layer != nullptr); if (use_asymmetric_padding) { deconv_layer->setPrePadding(prepadding); deconv_layer->setPostPadding(postpadding); } else { deconv_layer->setPaddingNd(prepadding); } - ICHECK_EQ(strides.size(), 3); + TVM_FFI_ICHECK_EQ(strides.size(), 3); const auto trt_strides = nvinfer1::Dims3( static_cast(strides[0]), static_cast(strides[1]), static_cast(strides[2])); deconv_layer->setStrideNd(trt_strides); @@ -1105,7 +1105,7 @@ class Conv3DTransposeOpConverter : public TensorRTOpConverter { if (output_padding.size()) { GetPadding3D(output_padding, &use_asymmetric_padding, &prepadding, &postpadding); // Are any post-padding values non-zero? - ICHECK(!std::any_of(postpadding.d, postpadding.d + postpadding.nbDims, [](int x) { + TVM_FFI_ICHECK(!std::any_of(postpadding.d, postpadding.d + postpadding.nbDims, [](int x) { return x != 0; })) << "TRT does not support padding on 3 dimensions."; } @@ -1170,7 +1170,7 @@ class ReshapeOpConverter : public TensorRTOpConverter { if (static_cast(newshape[0]) == -1) start_index = 0; for (size_t i = start_index; i < newshape.size(); ++i) { const int value = static_cast(newshape[i]); - ICHECK_GE(value, -1); + TVM_FFI_ICHECK_GE(value, -1); new_shape.push_back(value); } params->outputs.push_back(Reshape(params, input, new_shape)); @@ -1209,14 +1209,14 @@ class ReduceOpConverter : public TensorRTOpConverter { {"min", nvinfer1::ReduceOperation::kMIN}, {"mean", nvinfer1::ReduceOperation::kAVG}}; auto it = op_map.find(op_name); - ICHECK(it != op_map.end()) << "Unsupported reduce type " << op_name; + TVM_FFI_ICHECK(it != op_map.end()) << "Unsupported reduce type " << op_name; auto input = params->inputs.at(0).tensor; - ICHECK_EQ(static_cast(params->node.GetAttr("exclude")), false); + TVM_FFI_ICHECK_EQ(static_cast(params->node.GetAttr("exclude")), false); bool keepdims = static_cast(params->node.GetAttr("keepdims")); auto axes = params->node.GetAttr>("axis"); // TODO(trevmorr): Support reduce to scalar. - ICHECK_GT(axes.size(), 0); + TVM_FFI_ICHECK_GT(axes.size(), 0); uint32_t reduce_axes = 0; for (size_t i = 0; i < axes.size(); ++i) { @@ -1274,8 +1274,8 @@ class AdaptivePoolingOpConverter : public TensorRTOpConverter { {"nn.adaptive_max_pool2d", nvinfer1::PoolingType::kMAX}, {"nn.adaptive_avg_pool2d", nvinfer1::PoolingType::kAVERAGE}}; auto it = op_map.find(op_name); - ICHECK(it != op_map.end()) << "Unsupported pooling type " << op_name << " in TensorRT"; - ICHECK_EQ(params->node.GetAttr("layout"), "NCHW"); + TVM_FFI_ICHECK(it != op_map.end()) << "Unsupported pooling type " << op_name << " in TensorRT"; + TVM_FFI_ICHECK_EQ(params->node.GetAttr("layout"), "NCHW"); // This is an approximation of adaptive pooling. Results will not be // mathematically exact except when output_size is (1, 1). @@ -1287,7 +1287,7 @@ class AdaptivePoolingOpConverter : public TensorRTOpConverter { const auto window_size = nvinfer1::DimsHW(h - (output_size.h() - 1) * stride.h(), w - (output_size.w() - 1) * stride.w()); auto pool_layer = params->network->addPooling(*input_tensor, it->second, window_size); - ICHECK(pool_layer != nullptr); + TVM_FFI_ICHECK(pool_layer != nullptr); pool_layer->setStride(stride); params->outputs.push_back(pool_layer->getOutput(0)); } @@ -1308,7 +1308,7 @@ class BatchMatmulOpConverter : public TensorRTOpConverter { transb ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE; nvinfer1::IMatrixMultiplyLayer* matmul_layer = params->network->addMatrixMultiply( *params->inputs.at(0).tensor, trt_transa, *params->inputs.at(1).tensor, trt_transb); - ICHECK(matmul_layer != nullptr); + TVM_FFI_ICHECK(matmul_layer != nullptr); params->outputs.push_back(matmul_layer->getOutput(0)); } }; diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc index ac794bf8fa05..df8443dd590a 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc @@ -80,15 +80,15 @@ class TensorRTRuntime : public JSONRuntimeBase { multi_engine_mode_ = support::GetEnv("TVM_TENSORRT_MULTI_ENGINE", false); num_calibration_batches_remaining_ = support::GetEnv("TENSORRT_NUM_CALI_INT8", 0); if (use_int8) { - ICHECK(num_calibration_batches_remaining_ != 0) + TVM_FFI_ICHECK(num_calibration_batches_remaining_ != 0) << "When using INT8 mode, " << "environment variable TENSORRT_NUM_CALI_INT8" << "must also be set to specify the number of " << "calibration times"; LOG(INFO) << "settiing up " << num_calibration_batches_remaining_ << " sample data to calibrate data ... "; - ICHECK(multi_engine_mode_ == false) << "When using int8 mode, " - << "multi-engine is not allowed"; + TVM_FFI_ICHECK(multi_engine_mode_ == false) << "When using int8 mode, " + << "multi-engine is not allowed"; } } @@ -111,7 +111,7 @@ class TensorRTRuntime : public JSONRuntimeBase { * \param consts The constant params from compiled model. */ void Init(const ffi::Array& consts) override { - ICHECK_EQ(consts.size(), const_idx_.size()) + TVM_FFI_ICHECK_EQ(consts.size(), const_idx_.size()) << "The number of input constants must match the number of required."; LoadGlobalAttributes(); SetupConstants(consts); @@ -178,13 +178,13 @@ class TensorRTRuntime : public JSONRuntimeBase { uint32_t eid = EntryID(nid, j); const std::string name = nodes_[nid].GetOpName() + "_" + std::to_string(j); int binding_index = engine->getBindingIndex(name.c_str()); - ICHECK_NE(binding_index, -1); + TVM_FFI_ICHECK_NE(binding_index, -1); #if TRT_VERSION_GE(6, 0, 1) if (!use_implicit_batch_) { std::vector shape(data_entry_[eid]->shape, data_entry_[eid]->shape + data_entry_[eid]->ndim); auto dims = VectorToTrtDims(shape); - ICHECK(context->setBindingDimensions(binding_index, dims)); + TVM_FFI_ICHECK(context->setBindingDimensions(binding_index, dims)); } #endif if (data_entry_[eid]->device.device_type == kDLCUDA) { @@ -219,7 +219,7 @@ class TensorRTRuntime : public JSONRuntimeBase { uint32_t eid = EntryID(outputs_[i]); const std::string& name = engine_and_context.outputs[i]; int binding_index = engine->getBindingIndex(name.c_str()); - ICHECK_NE(binding_index, -1); + TVM_FFI_ICHECK_NE(binding_index, -1); if (data_entry_[eid]->device.device_type == kDLCUDA) { bindings[binding_index] = data_entry_[eid]->data; } else { @@ -230,12 +230,12 @@ class TensorRTRuntime : public JSONRuntimeBase { #if TRT_VERSION_GE(6, 0, 1) if (use_implicit_batch_) { - ICHECK(context->execute(batch_size, bindings.data())) << "Running TensorRT failed."; + TVM_FFI_ICHECK(context->execute(batch_size, bindings.data())) << "Running TensorRT failed."; } else { - ICHECK(context->executeV2(bindings.data())) << "Running TensorRT failed."; + TVM_FFI_ICHECK(context->executeV2(bindings.data())) << "Running TensorRT failed."; } #else - ICHECK(context->execute(batch_size, bindings.data())) << "Running TensorRT failed."; + TVM_FFI_ICHECK(context->execute(batch_size, bindings.data())) << "Running TensorRT failed."; #endif // Copy outputs from GPU buffers if needed. @@ -243,7 +243,7 @@ class TensorRTRuntime : public JSONRuntimeBase { uint32_t eid = EntryID(outputs_[i]); const std::string& name = engine_and_context.outputs[i]; int binding_index = engine->getBindingIndex(name.c_str()); - ICHECK_NE(binding_index, -1); + TVM_FFI_ICHECK_NE(binding_index, -1); if (data_entry_[eid]->device.device_type != kDLCUDA) { auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); device_buffer.CopyTo(const_cast(data_entry_[eid])); @@ -333,7 +333,7 @@ class TensorRTRuntime : public JSONRuntimeBase { if (node.GetOpType() == "input") { builder.AddInput(nid, EntryID(nid, 0), node); } else { - ICHECK_EQ(node.GetOpType(), "const"); + TVM_FFI_ICHECK_EQ(node.GetOpType(), "const"); uint32_t eid = EntryID(nid, 0); builder.AddConstant(nid, data_entry_[eid]); } @@ -504,8 +504,8 @@ class TensorRTRuntime : public JSONRuntimeBase { #else // TVM_GRAPH_EXECUTOR_TENSORRT void Run() override { - LOG(FATAL) << "TensorRT runtime is not enabled. " - << "Please build with USE_TENSORRT_RUNTIME."; + TVM_FFI_THROW(InternalError) << "TensorRT runtime is not enabled. " + << "Please build with USE_TENSORRT_RUNTIME."; } void BuildEngine() { diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 9029a62f8da0..4c4d6c9cae2a 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -31,42 +31,42 @@ namespace tvm { namespace runtime { -#define TVM_DTYPE_DISPATCH(type, DType, ...) \ - if (type == DataType::Float(64)) { \ - typedef double DType; \ - { __VA_ARGS__ } \ - } else if (type == DataType::Float(32)) { \ - typedef float DType; \ - { __VA_ARGS__ } \ - } else if (type == DataType::Float(16)) { \ - typedef uint16_t DType; \ - { __VA_ARGS__ } \ - } else if (type == DataType::Int(64)) { \ - typedef int64_t DType; \ - { __VA_ARGS__ } \ - } else if (type == DataType::Int(32)) { \ - typedef int32_t DType; \ - { __VA_ARGS__ } \ - } else if (type == DataType::Int(16)) { \ - typedef int16_t DType; \ - { __VA_ARGS__ } \ - } else if (type == DataType::Int(8)) { \ - typedef int8_t DType; \ - { __VA_ARGS__ } \ - } else if (type == DataType::UInt(64)) { \ - typedef uint64_t DType; \ - { __VA_ARGS__ } \ - } else if (type == DataType::UInt(32)) { \ - typedef uint32_t DType; \ - { __VA_ARGS__ } \ - } else if (type == DataType::UInt(16)) { \ - typedef uint16_t DType; \ - { __VA_ARGS__ } \ - } else if (type == DataType::UInt(8)) { \ - typedef uint8_t DType; \ - { __VA_ARGS__ } \ - } else { \ - LOG(FATAL) << "unknown data type " << type; \ +#define TVM_DTYPE_DISPATCH(type, DType, ...) \ + if (type == DataType::Float(64)) { \ + typedef double DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Float(32)) { \ + typedef float DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Float(16)) { \ + typedef uint16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(64)) { \ + typedef int64_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(32)) { \ + typedef int32_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(16)) { \ + typedef int16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(8)) { \ + typedef int8_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(64)) { \ + typedef uint64_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(32)) { \ + typedef uint32_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(16)) { \ + typedef uint16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(8)) { \ + typedef uint8_t DType; \ + { __VA_ARGS__ } \ + } else { \ + TVM_FFI_THROW(InternalError) << "unknown data type " << type; \ } DataType TfLiteDType2TVMDType(TfLiteType dtype) { @@ -86,7 +86,7 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) { case kTfLiteFloat16: return DataType::Float(16); default: - LOG(FATAL) << "tflite data type not support yet: " << dtype; + TVM_FFI_THROW(InternalError) << "tflite data type not support yet: " << dtype; TVM_FFI_UNREACHABLE(); } } @@ -118,7 +118,7 @@ void TFLiteRuntime::SetInput(int index, DLTensor* data_in) { TVM_DTYPE_DISPATCH(dtype, DType, { DType* dest = interpreter_->typed_input_tensor(index); DType* src = static_cast(data_in->data); - ICHECK(ffi::IsContiguous(*data_in)); + TVM_FFI_ICHECK(ffi::IsContiguous(*data_in)); int64_t size = 1; for (int64_t i = 0; i < data_in->ndim; ++i) { size *= data_in->shape[i]; @@ -158,7 +158,7 @@ ffi::Optional TFLiteRuntime::GetFunction(const ffi::String& name) if (name == "set_input") { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { int in_idx = args[0].cast(); - ICHECK_GE(in_idx, 0); + TVM_FFI_ICHECK_GE(in_idx, 0); this->SetInput(in_idx, args[1].cast()); }); } else if (name == "get_output") { @@ -171,7 +171,7 @@ ffi::Optional TFLiteRuntime::GetFunction(const ffi::String& name) } else if (name == "set_num_threads") { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { int num_threads = args[0].cast(); - CHECK_GE(num_threads, 1); + TVM_FFI_ICHECK_GE(num_threads, 1); this->SetNumThreads(num_threads); }); } else { diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index a5703ee70749..d231932a948b 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -38,7 +38,7 @@ namespace tvm { namespace runtime { -#define CHECK_TFLITE_STATUS(ret) ICHECK_EQ(ret, kTfLiteOk) +#define CHECK_TFLITE_STATUS(ret) TVM_FFI_ICHECK_EQ(ret, kTfLiteOk) /*! * \brief Tflite runtime. diff --git a/src/runtime/cuda/cuda_common.h b/src/runtime/cuda/cuda_common.h index fd032fc75bd1..7fe2e0d1672b 100644 --- a/src/runtime/cuda/cuda_common.h +++ b/src/runtime/cuda/cuda_common.h @@ -40,16 +40,18 @@ namespace runtime { if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { \ const char* msg; \ cuGetErrorName(result, &msg); \ - LOG(FATAL) << "CUDAError: " #x " failed with error: " << msg; \ + TVM_FFI_THROW(CUDAError) << "" #x " failed with error: " << msg; \ } \ } -#define CUDA_CALL(func) \ - { \ - cudaError_t e = (func); \ - ICHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ - << "CUDA: " << cudaGetErrorString(e); \ +#ifndef CUDA_CALL +#define CUDA_CALL(func) \ + { \ + cudaError_t e = (func); \ + TVM_FFI_ICHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ + << "CUDA: " << cudaGetErrorString(e); \ } +#endif /*! \brief Thread local workspace */ class CUDAThreadEntry { diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 665db68b265e..26eafc86678a 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -134,7 +134,7 @@ class CUDADeviceAPI final : public DeviceAPI { *rv = value; } void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final { - ICHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes"; + TVM_FFI_ICHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes"; void* ret; if (dev.device_type == kDLCUDAHost) { VLOG(1) << "allocating " << nbytes << "bytes on host"; @@ -215,7 +215,7 @@ class CUDADeviceAPI final : public DeviceAPI { CUDA_CALL(cudaSetDevice(dev_to.device_id)); GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream); } else { - LOG(FATAL) << "expect copy from/to GPU or between GPU"; + TVM_FFI_THROW(InternalError) << "expect copy from/to GPU or between GPU"; } } @@ -398,14 +398,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("runtime.cuTensorMapEncodeTiled", [](ffi::PackedArgs args, ffi::Any* rv) { - CHECK_GE(args.size(), 4) << "init_cuTensorMap expects at least 4 arguments"; + TVM_FFI_ICHECK_GE(args.size(), 4) << "init_cuTensorMap expects at least 4 arguments"; size_t arg_cnt = 0; CUtensorMap* tensor_map = static_cast(args[arg_cnt++].cast()); runtime::DataType tensor_dtype = args[arg_cnt++].cast(); uint32_t tensor_rank = static_cast(args[arg_cnt++].cast()); void* tensor_ptr = static_cast(args[arg_cnt++].cast()); - CHECK_EQ(args.size(), 4 + tensor_rank * 4 + 3) + TVM_FFI_ICHECK_EQ(args.size(), 4 + tensor_rank * 4 + 3) << "cuTensorMapEncodeTiled expects " << 4 + tensor_rank * 4 + 3 << " arguments" << "tensor_map, tensor_dtype, tensor_rank, tensor_ptr, global_shape(" << tensor_rank << "), global_strides(" << tensor_rank - 1 << "), shared_shape(" << tensor_rank @@ -421,12 +421,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { } for (size_t i = 0; i < tensor_rank - 1; ++i) { global_strides[i] = static_cast(args[arg_cnt++].cast()); - CHECK_EQ(global_strides[i] % 16, 0) << "global strides must be multiple of 16"; + TVM_FFI_ICHECK_EQ(global_strides[i] % 16, 0) << "global strides must be multiple of 16"; } for (size_t i = 0; i < tensor_rank; ++i) { shared_shape[i] = static_cast(args[arg_cnt++].cast()); - CHECK_GE(shared_shape[i], 0) << "boxDim must be non-negative"; - CHECK_LE(shared_shape[i], 256) << "boxDim must be less than or equal to 256"; + TVM_FFI_ICHECK_GE(shared_shape[i], 0) << "boxDim must be non-negative"; + TVM_FFI_ICHECK_LE(shared_shape[i], 256) << "boxDim must be less than or equal to 256"; } for (size_t i = 0; i < tensor_rank; ++i) { shared_strides[i] = static_cast(args[arg_cnt++].cast()); @@ -436,7 +436,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { auto l2_promotion_kind = static_cast(args[arg_cnt++].cast()); auto oob_fill_kind = static_cast(args[arg_cnt++].cast()); - ICHECK_EQ(tensor_dtype.lanes(), 1) + TVM_FFI_ICHECK_EQ(tensor_dtype.lanes(), 1) << "Expect tensor_dtype to have lanes=1, but get " << tensor_dtype; CUtensorMapDataType cu_dtype; switch (tensor_dtype.code()) { @@ -453,7 +453,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { cu_dtype = CU_TENSOR_MAP_DATA_TYPE_INT64; break; default: - LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); + TVM_FFI_THROW(InternalError) + << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); } break; case DataType::kUInt: @@ -472,7 +473,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT64; break; default: - LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); + TVM_FFI_THROW(InternalError) + << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); } break; case DataType::kFloat: @@ -488,7 +490,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { cu_dtype = CU_TENSOR_MAP_DATA_TYPE_FLOAT64; break; default: - LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); + TVM_FFI_THROW(InternalError) + << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); } break; case DataType::kBFloat: @@ -498,7 +501,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { cu_dtype = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; break; default: - LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); + TVM_FFI_THROW(InternalError) + << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); } break; case DataType::kFloat8_e4m3fn: @@ -510,24 +514,25 @@ TVM_FFI_STATIC_INIT_BLOCK() { cu_dtype = CU_TENSOR_MAP_DATA_TYPE_UINT8; break; default: - LOG(FATAL) << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); + TVM_FFI_THROW(InternalError) + << "Unsupported data type " << runtime::DLDataTypeToString(tensor_dtype); } // sanity checks per cuTensorMapEncodeTiled requirements // see // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 - CHECK_EQ((reinterpret_cast(tensor_ptr) & 0b1111), 0); // 16-byte alignment - CHECK_EQ((reinterpret_cast(tensor_map) & 0b111111), 0); // 64-byte alignment - CHECK_LE(tensor_rank, 5) << "cuTensorMapEncodeTiled only supports up to 5D tensors"; + TVM_FFI_ICHECK_EQ((reinterpret_cast(tensor_ptr) & 0b1111), 0); // 16-byte alignment + TVM_FFI_ICHECK_EQ((reinterpret_cast(tensor_map) & 0b111111), 0); // 64-byte alignment + TVM_FFI_ICHECK_LE(tensor_rank, 5) << "cuTensorMapEncodeTiled only supports up to 5D tensors"; if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_32B) { - CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 32) + TVM_FFI_ICHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 32) << "CU_TENSOR_MAP_SWIZZLE_32B implies the bounding box inner dimension will be <= 32."; } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_64B) { - CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 64) + TVM_FFI_ICHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 64) << "CU_TENSOR_MAP_SWIZZLE_64B implies the bounding box inner dimension will be <= 64."; } else if (swizzle_kind == CU_TENSOR_MAP_SWIZZLE_128B) { - CHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 128) + TVM_FFI_ICHECK_LE(shared_shape[0] * tensor_dtype.bytes(), 128) << "CU_TENSOR_MAP_SWIZZLE_128B implies the bounding box inner dimension will be <= " "128."; } @@ -575,7 +580,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { std::cout << shared_strides[i] << " "; } std::cout << "\n"; - CHECK_EQ(res, CUDA_SUCCESS) << "Error in cuTensorMapEncodeTiled: " << errstr; + TVM_FFI_ICHECK_EQ(res, CUDA_SUCCESS) << "Error in cuTensorMapEncodeTiled: " << errstr; } }); } diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 92fc29d4236f..d529af126bdb 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -83,11 +83,11 @@ class CUDAModuleNode : public ffi::ModuleObj { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); if (fmt == "cu") { - ICHECK_NE(cuda_source_.length(), 0); + TVM_FFI_ICHECK_NE(cuda_source_.length(), 0); SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, cuda_source_); } else { - ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; + TVM_FFI_ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, data_); } @@ -128,7 +128,8 @@ class CUDAModuleNode : public ffi::ModuleObj { if (result != CUDA_SUCCESS) { const char* msg; cuGetErrorName(result, &msg); - LOG(FATAL) << "CUDAError: cuModuleGetFunction " << func_name << " failed with error: " << msg; + TVM_FFI_THROW(CUDAError) << "cuModuleGetFunction " << func_name + << " failed with error: " << msg; } return func; } @@ -147,11 +148,12 @@ class CUDAModuleNode : public ffi::ModuleObj { size_t nbytes; CUresult result = cuModuleGetGlobal(&global, &nbytes, module_[device_id], global_name.c_str()); - ICHECK_EQ(nbytes, expect_nbytes); + TVM_FFI_ICHECK_EQ(nbytes, expect_nbytes); if (result != CUDA_SUCCESS) { const char* msg; cuGetErrorName(result, &msg); - LOG(FATAL) << "CUDAError: cuModuleGetGlobal " << global_name << " failed with error: " << msg; + TVM_FFI_THROW(CUDAError) << "cuModuleGetGlobal " << global_name + << " failed with error: " << msg; } return global; } @@ -197,8 +199,8 @@ class CUDAWrappedFunc { CUresult result = cuFuncSetAttribute( fcache_[device_id], CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, wl.dyn_shmem_size); if (result != CUDA_SUCCESS) { - LOG(FATAL) << "Failed to set the allowed dynamic shared memory size to " - << wl.dyn_shmem_size; + TVM_FFI_THROW(InternalError) + << "Failed to set the allowed dynamic shared memory size to " << wl.dyn_shmem_size; } } } @@ -248,7 +250,7 @@ class CUDAWrappedFunc { << "// -----------\n" << cuda; } - LOG(FATAL) << os.str(); + TVM_FFI_THROW(InternalError) << os.str(); } } @@ -293,7 +295,7 @@ class CUDAPrepGlobalBarrier { ffi::Optional CUDAModuleNode::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); - ICHECK_EQ(sptr_to_self.get(), this); + TVM_FFI_ICHECK_EQ(sptr_to_self.get(), this); if (name == symbol::tvm_prepare_global_barrier) { return ffi::Function(CUDAPrepGlobalBarrier(this, sptr_to_self)); } @@ -328,7 +330,7 @@ ffi::Module CUDAModuleLoadFromBytes(const ffi::Bytes& bytes) { ffi::Map fmap; std::string fmt; stream.Read(&fmt); - ICHECK(stream.Read(&fmap)); + TVM_FFI_ICHECK(stream.Read(&fmap)); stream.Read(&data); return CUDAModuleCreate(data, fmt, fmap, std::string()); } diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/runtime/cuda/l2_cache_flush.cc index 90003df50693..5a1d4da0e70a 100644 --- a/src/runtime/cuda/l2_cache_flush.cc +++ b/src/runtime/cuda/l2_cache_flush.cc @@ -37,7 +37,7 @@ L2Flush* L2Flush::ThreadLocal() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("l2_cache_flush_cuda", [](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal do not exist."; + TVM_FFI_ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal do not exist."; int device_id; CUDA_CALL(cudaGetDevice(&device_id)); cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc index d2d553f138d0..f95262424d31 100644 --- a/src/runtime/device_api.cc +++ b/src/runtime/device_api.cc @@ -86,7 +86,7 @@ class DeviceAPIManager { std::string factory = "device_api." + name; const auto f = tvm::ffi::Function::GetGlobal(factory); if (!f.has_value()) { - ICHECK(allow_missing) << "Device API " << name << " is not enabled."; + TVM_FFI_ICHECK(allow_missing) << "Device API " << name << " is not enabled."; return nullptr; } void* ptr = (*f)().cast(); @@ -116,8 +116,8 @@ size_t DeviceAPI::GetDataSize(const DLTensor& arr, ffi::Optional me } return ffi::GetDataSize(size, arr.dtype); } - LOG(FATAL) << "Device does not support physical mem computation with " - << "specified memory scope: " << mem_scope.value(); + TVM_FFI_THROW(InternalError) << "Device does not support physical mem computation with " + << "specified memory scope: " << mem_scope.value(); return 0; } @@ -137,17 +137,17 @@ void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDa size_t alignment = GetDataAlignment(temp.dtype); return AllocDataSpace(dev, size, alignment, dtype); } - LOG(FATAL) << "Device does not support allocate data space with " - << "specified memory scope: " << mem_scope.value(); + TVM_FFI_THROW(InternalError) << "Device does not support allocate data space with " + << "specified memory scope: " << mem_scope.value(); return nullptr; } void DeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { // by default, we can always redirect to the flat memory copy operation. size_t nbytes = GetDataSize(*from); - ICHECK_EQ(nbytes, GetDataSize(*to)); + TVM_FFI_ICHECK_EQ(nbytes, GetDataSize(*to)); - ICHECK(ffi::IsContiguous(*from) && ffi::IsContiguous(*to)) + TVM_FFI_ICHECK(ffi::IsContiguous(*from) && ffi::IsContiguous(*to)) << "CopyDataFromTo only support contiguous array for now"; CopyDataFromTo(from->data, from->byte_offset, to->data, to->byte_offset, nbytes, from->device, to->device, from->dtype, stream); @@ -156,7 +156,7 @@ void DeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle str void DeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t num_bytes, Device dev_from, Device dev_to, DLDataType type_hint, TVMStreamHandle stream) { - LOG(FATAL) << "Device does not support CopyDataFromTo."; + TVM_FFI_THROW(InternalError) << "Device does not support CopyDataFromTo."; } void DeviceAPI::FreeWorkspace(Device dev, void* ptr) { FreeDataSpace(dev, ptr); } diff --git a/src/runtime/disco/bcast_session.cc b/src/runtime/disco/bcast_session.cc index 2ea9ef575d05..8828bdfae821 100644 --- a/src/runtime/disco/bcast_session.cc +++ b/src/runtime/disco/bcast_session.cc @@ -69,19 +69,20 @@ void BcastSessionObj::Shutdown() { void BcastSessionObj::InitCCL(ffi::String ccl, ffi::Shape device_ids) { const auto pf = tvm::ffi::Function::GetGlobal("runtime.disco." + ccl + ".init_ccl"); - CHECK(pf.has_value()) << "ValueError: Cannot initialize CCL `" << ccl - << "`, because cannot find function: runtime.disco." << ccl << ".init_ccl"; + TVM_FFI_CHECK(pf.has_value(), ValueError) + << "Cannot initialize CCL `" << ccl << "`, because cannot find function: runtime.disco." + << ccl << ".init_ccl"; (*pf)(ffi::GetRef(this), device_ids); } void BcastSessionObj::SyncWorker(int worker_id) { BcastSessionObj::Internal::BroadcastUnpacked(this, DiscoAction::kSyncWorker, worker_id); ffi::PackedArgs args = this->RecvReplyPacked(worker_id); - ICHECK_EQ(args.size(), 2); + TVM_FFI_ICHECK_EQ(args.size(), 2); DiscoAction action = static_cast(args[0].cast()); int ret_worker_id = args[1].cast(); - ICHECK(action == DiscoAction::kSyncWorker); - ICHECK_EQ(ret_worker_id, worker_id); + TVM_FFI_ICHECK(action == DiscoAction::kSyncWorker); + TVM_FFI_ICHECK_EQ(ret_worker_id, worker_id); } DRef BcastSessionObj::CallWithPacked(const ffi::PackedArgs& args) { diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc index 8584d15c5e04..57e1de17de31 100644 --- a/src/runtime/disco/builtin.cc +++ b/src/runtime/disco/builtin.cc @@ -61,8 +61,9 @@ ffi::Module LoadVMModule(std::string path, ffi::Optional device) { auto mod = (*vm_load_executable)().cast(); ffi::Optional vm_initialization = mod->GetFunction("vm_initialization"); if (!vm_initialization.has_value()) { - LOG(FATAL) << "ValueError: File `" << path - << "` is not built by RelaxVM, because `vm_initialization` does not exist"; + TVM_FFI_THROW(ValueError) + << "File `" << path + << "` is not built by RelaxVM, because `vm_initialization` does not exist"; } (*vm_initialization)(static_cast(dev.device_type), static_cast(dev.device_id), static_cast(AllocatorType::kPooled), static_cast(kDLCPU), 0, @@ -78,8 +79,8 @@ ffi::Function GetCCLFunc(const char* name) { std::string ccl = DiscoWorker::ThreadLocal()->ccl; std::string pf_name = "runtime.disco." + ccl + "." + name; const auto pf = tvm::ffi::Function::GetGlobal(pf_name); - CHECK(pf.has_value()) << "ValueError: Cannot find the `" << name << "` function for `" << ccl - << "` via `" << pf_name << "`"; + TVM_FFI_CHECK(pf.has_value(), ValueError) + << "Cannot find the `" << name << "` function for `" << ccl << "` via `" << pf_name << "`"; return *pf; } @@ -146,7 +147,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("runtime.disco.allreduce", [](Tensor send, ffi::Shape reduce_kind, bool in_group, Tensor recv) { int kind = IntegerFromShape(reduce_kind); - CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; + TVM_FFI_CHECK(0 <= kind && kind <= 4, ValueError) << "Unknown ReduceKind: " << kind; AllReduce(send, static_cast(kind), in_group, recv); }) .def("runtime.disco.allgather", AllGather) @@ -164,7 +165,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { []() -> Device { return DiscoWorker::ThreadLocal()->default_device; }) .def("runtime.disco.bind_worker_to_cpu_core", [](ffi::Shape cpu_ids) { int worker_id = WorkerId(); - ICHECK_LT(worker_id, static_cast(cpu_ids.size())); + TVM_FFI_ICHECK_LT(worker_id, static_cast(cpu_ids.size())); const auto f_set_thread_affinity = tvm::ffi::Function::GetGlobalRequired( "tvm.runtime.threading.set_current_thread_affinity"); f_set_thread_affinity(ffi::Shape{cpu_ids[worker_id]}); diff --git a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc index 7dc55c0b4b7c..23c9613e5e6d 100644 --- a/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc +++ b/src/runtime/disco/cuda_ipc/cuda_ipc_memory.cc @@ -77,7 +77,7 @@ class CUDAIPCMemoryAllocator final : public memory::PooledAllocator { CUDAIPCMemory GetIPCMemoryFromDevicePtr(void* ptr) const { auto it = ipc_memory_map_.find(ptr); - CHECK(it != ipc_memory_map_.end()) + TVM_FFI_ICHECK(it != ipc_memory_map_.end()) << "The given pointer's CUDAIPCMemory object does not exist. Please use global function " "\"cuda_ipc.alloc_storage\" to allocate the CUDAIPCMemory object first."; return it->second; @@ -114,11 +114,11 @@ class CUDAIPCMemoryAllocator final : public memory::PooledAllocator { } void DeviceFreeDataSpace(Device dev, void* ptr) final { - ICHECK(dev.device_type == kDLCUDA); + TVM_FFI_ICHECK(dev.device_type == kDLCUDA); CUDA_CALL(cudaSetDevice(dev.device_id)); nccl::CCLThreadLocalContext* ctx = nccl::CCLThreadLocalContext::Get(); auto it = ipc_memory_map_.find(ptr); - ICHECK(it != ipc_memory_map_.end()); + TVM_FFI_ICHECK(it != ipc_memory_map_.end()); FreeIPCMemory(it->second->remote_data, ctx->worker->worker_id); FreeIPCMemory(it->second->barrier_in, ctx->worker->worker_id); FreeIPCMemory(it->second->barrier_out, ctx->worker->worker_id); @@ -144,7 +144,7 @@ class CUDAIPCMemoryAllocator final : public memory::PooledAllocator { DLDataType type_hint, bool reset_memory_to_zero) { // Alloc local buffer - ICHECK(dev.device_type == kDLCUDA); + TVM_FFI_ICHECK(dev.device_type == kDLCUDA); void* ptr; CUDA_CALL(cudaSetDevice(dev.device_id)); CUDA_CALL(cudaMalloc(&ptr, size)); diff --git a/src/runtime/disco/cuda_ipc/custom_allreduce.cc b/src/runtime/disco/cuda_ipc/custom_allreduce.cc index 060a098a9d63..69652c7e82c9 100644 --- a/src/runtime/disco/cuda_ipc/custom_allreduce.cc +++ b/src/runtime/disco/cuda_ipc/custom_allreduce.cc @@ -38,7 +38,7 @@ inline int64_t TensorSize(const DLTensor* tensor) { int64_t size = 1; for (int i = tensor->ndim - 1; i >= 0; --i) { if (tensor->strides) { - ICHECK_EQ(tensor->strides[i], size); + TVM_FFI_ICHECK_EQ(tensor->strides[i], size); } size *= tensor->shape[i]; } @@ -66,7 +66,7 @@ inline bool CanApplyTwoShotAllReduce(int64_t num_elements, DLDataType dtype, int void CustomAllReduce(DLTensor* send, int strategy, DLTensor* recv) { int64_t num_elements = TensorSize(send); nccl::CCLThreadLocalContext* ctx = nccl::CCLThreadLocalContext::Get(); - CHECK_EQ(ctx->worker->num_groups, 1) + TVM_FFI_ICHECK_EQ(ctx->worker->num_groups, 1) << "Custom AllReduce for multiple group is not yet implemented."; tensorrt_llm::AllReduceStrategyType strategy_ = diff --git a/src/runtime/disco/disco_worker.cc b/src/runtime/disco/disco_worker.cc index d9865ca2bec4..a7ff1a5a035a 100644 --- a/src/runtime/disco/disco_worker.cc +++ b/src/runtime/disco/disco_worker.cc @@ -29,12 +29,12 @@ namespace runtime { TVM_DLL DiscoWorker* DiscoWorker::ThreadLocal() { DiscoWorker* ret = ThreadLocalDiscoWorker::Get()->worker; - CHECK(ret) << "ValueError: The current thread is not a DiscoWorker thread"; + TVM_FFI_CHECK(ret, ValueError) << "The current thread is not a DiscoWorker thread"; return ret; } void DiscoWorker::SetRegister(int reg_id, ffi::AnyView value) { - ICHECK(0 <= reg_id && reg_id < static_cast(register_file.size())); + TVM_FFI_ICHECK(0 <= reg_id && reg_id < static_cast(register_file.size())); ffi::Any& rv = register_file.at(reg_id); if (rv.type_index() == ffi::TypeIndex::kTVMFFITensor && value.type_index() == ffi::TypeIndex::kTVMFFITensor) { @@ -69,9 +69,9 @@ struct DiscoWorker::Impl { } case DiscoAction::kCallPacked: { int func_reg_id = args[2].cast(); - CHECK_LT(func_reg_id, self->register_file.size()); + TVM_FFI_ICHECK_LT(func_reg_id, self->register_file.size()); ffi::Function func = GetReg(self, func_reg_id).cast(); - CHECK(func.defined()); + TVM_FFI_ICHECK(func.defined()); CallPacked(self, reg_id, func, args.Slice(3)); break; } @@ -106,7 +106,7 @@ struct DiscoWorker::Impl { static void GetGlobalFunc(DiscoWorker* self, int reg_id, const std::string& name) { const auto pf = tvm::ffi::Function::GetGlobal(name); - CHECK(pf.has_value()) << "ValueError: Cannot find global function: " << name; + TVM_FFI_CHECK(pf.has_value(), ValueError) << "Cannot find global function: " << name; if (reg_id != 0) { GetReg(self, reg_id) = *pf; } diff --git a/src/runtime/disco/distributed/socket_session.cc b/src/runtime/disco/distributed/socket_session.cc index 99c54933bf3a..def2522b60d3 100644 --- a/src/runtime/disco/distributed/socket_session.cc +++ b/src/runtime/disco/distributed/socket_session.cc @@ -60,7 +60,7 @@ class SocketSessionObj : public BcastSessionObj { : num_nodes_(num_nodes), num_workers_per_node_(num_workers_per_node) { const auto f_create_local_session = tvm::ffi::Function::GetGlobal("runtime.disco.create_socket_session_local_workers"); - ICHECK(f_create_local_session.has_value()) + TVM_FFI_ICHECK(f_create_local_session.has_value()) << "Cannot find function runtime.disco.create_socket_session_local_workers"; local_session_ = ((*f_create_local_session)(num_workers_per_node)).cast(); DRef f_init_workers = @@ -107,8 +107,9 @@ class SocketSessionObj : public BcastSessionObj { static_cast(DiscoAction::kDebugGetFromRemote), reg_id, worker_id); remote_channels_[node_id - 1]->Send(ffi::PackedArgs(packed_args, 5)); ffi::PackedArgs args = this->RecvReplyPacked(worker_id); - ICHECK_EQ(args.size(), 2); - ICHECK(static_cast(args[0].cast()) == DiscoAction::kDebugGetFromRemote); + TVM_FFI_ICHECK_EQ(args.size(), 2); + TVM_FFI_ICHECK(static_cast(args[0].cast()) == + DiscoAction::kDebugGetFromRemote); ffi::Any result; result = args[1]; return result; @@ -134,8 +135,9 @@ class SocketSessionObj : public BcastSessionObj { } ffi::Any result; ffi::PackedArgs args = this->RecvReplyPacked(worker_id); - ICHECK_EQ(args.size(), 1); - ICHECK(static_cast(args[0].cast()) == DiscoAction::kDebugSetRegister); + TVM_FFI_ICHECK_EQ(args.size(), 1); + TVM_FFI_ICHECK(static_cast(args[0].cast()) == + DiscoAction::kDebugSetRegister); } } @@ -215,17 +217,17 @@ class RemoteSocketSession { SockAddr server_addr{server_host.c_str(), server_port}; Socket::Startup(); if (!socket_.Connect(server_addr)) { - LOG(FATAL) << "Failed to connect to server " << server_addr.AsString() - << ", errno = " << Socket::GetLastErrorCode(); + TVM_FFI_THROW(InternalError) << "Failed to connect to server " << server_addr.AsString() + << ", errno = " << Socket::GetLastErrorCode(); } channel_ = std::make_unique(socket_); ffi::PackedArgs metadata = channel_->Recv(); - ICHECK_EQ(metadata.size(), 4); + TVM_FFI_ICHECK_EQ(metadata.size(), 4); num_nodes_ = metadata[0].cast(); num_workers_per_node_ = metadata[1].cast(); num_groups_ = metadata[2].cast(); node_id_ = metadata[3].cast(); - CHECK_GE(num_local_workers, num_workers_per_node_); + TVM_FFI_ICHECK_GE(num_local_workers, num_workers_per_node_); InitLocalSession(); } @@ -256,7 +258,7 @@ class RemoteSocketSession { return; } default: - LOG(FATAL) << "Invalid action " << static_cast(action); + TVM_FFI_THROW(InternalError) << "Invalid action " << static_cast(action); } } } diff --git a/src/runtime/disco/loader.cc b/src/runtime/disco/loader.cc index d80cf101a0f0..f8d66bf07b10 100644 --- a/src/runtime/disco/loader.cc +++ b/src/runtime/disco/loader.cc @@ -57,7 +57,7 @@ struct ShardInfo { std::unordered_map LoadShardInfoFromStr(const std::string& json_str); ShardInfo::TensorInfo LoadTensorInfoFromJSON(const json::Array& json_tensor_info) { - CHECK_EQ(json_tensor_info.size(), 2) << "ValueError: Invalid tensor info JSON"; + TVM_FFI_CHECK_EQ(json_tensor_info.size(), 2, ValueError) << "Invalid tensor info JSON"; json::Array shape_json = json_tensor_info[0].cast(); int ndim = shape_json.size(); std::vector shape; @@ -186,7 +186,7 @@ ObjectRef ShardLoaderObj::Create(const std::string& path_to_metadata, const std: } else if (const auto f = tvm::ffi::Function::GetGlobal(name)) { n->shard_funcs_[name] = *f; } else { - LOG(FATAL) << "ValueError: Undefined function: " << name; + TVM_FFI_THROW(ValueError) << "Undefined function: " << name; } } n->param_info_.emplace_back(ParamInfo{&file_record, ¶m_record, shard_info}); @@ -219,7 +219,7 @@ std::string GetSiblingPath(const std::string& path, const std::string& filename) if (found != std::string::npos) { return path.substr(0, found + 1) + filename; } - LOG(FATAL) << "ValueError: Cannot find the parent directory: " << path; + TVM_FFI_THROW(ValueError) << "Cannot find the parent directory: " << path; return ""; } @@ -256,18 +256,18 @@ std::tuple ParseParamShardingInfo(const ParamRecord* param) { std::string name = param->name; size_t pos1 = name.rfind("-of-"); - CHECK(pos1 != std::string::npos) + TVM_FFI_ICHECK(pos1 != std::string::npos) << "Attempt to read num_shards from unexpected param name: " << name; size_t pos2 = name.rfind("_shard-", pos1 - 1); - CHECK(pos2 != std::string::npos) + TVM_FFI_ICHECK(pos2 != std::string::npos) << "Attempt to read sharded worker_id from unexpected param name: " << name; int num_shards = std::stoi(name.substr(pos1 + 4)); int worker_id = std::stoi(name.substr(pos2 + 7, pos1 - pos2 - 7)) - 1; - CHECK_GT(num_shards, 1); - CHECK_GE(worker_id, 0); - CHECK_LT(worker_id, num_shards); + TVM_FFI_ICHECK_GT(num_shards, 1); + TVM_FFI_ICHECK_GE(worker_id, 0); + TVM_FFI_ICHECK_LT(worker_id, num_shards); return {num_shards, worker_id}; } @@ -300,8 +300,8 @@ Tensor ShardLoaderObj::Load(int weight_index) const { if (needs_sharding) { ffi::Shape shape = param_info.shard_info.funcs.back().output_info.shape; DataType dtype = param_info.shard_info.funcs.back().output_info.dtype; - ICHECK(shape.size() >= 1 && shape[0] == num_shards) - << "ValueError: The first dimension of the " + TVM_FFI_CHECK(shape.size() >= 1 && shape[0] == num_shards, ValueError) + << "The first dimension of the " << "output shape must be equal to the " << "number of shards, but got: " << shape << " and num_shards = " << num_shards; Tensor recv = Tensor::Empty(ffi::Shape(shape.begin() + 1, shape.end()), dtype, device); @@ -334,7 +334,7 @@ ffi::Array ShardLoaderObj::LoadAll() const { shards.reserve(n); for (int i = 0; i < n; ++i) { std::string param_name = "param_" + std::to_string(i); - ICHECK(this->param_name_to_index_.count(param_name)); + TVM_FFI_ICHECK(this->param_name_to_index_.count(param_name)); int shard_id = this->param_name_to_index_.at(param_name); shards.push_back(this->Load(shard_id)); } @@ -347,7 +347,7 @@ Tensor ShardLoaderObj::LoadPresharded(int weight_index) const { int num_shards = worker->num_workers; size_t num_weights = param_info_.size() / num_shards; size_t index = worker_id * num_weights + weight_index; - CHECK(index < param_info_.size()) + TVM_FFI_ICHECK(index < param_info_.size()) << "Loading param " << weight_index << " for shard " << worker_id << " at position " << index << " is out of bounds for the provided ndarray chace."; @@ -356,11 +356,11 @@ Tensor ShardLoaderObj::LoadPresharded(int weight_index) const { const FileRecord* file = shard_info.file; auto [p_num_shards, p_worker_id] = ParseParamShardingInfo(param); - CHECK_EQ(num_shards, p_num_shards) + TVM_FFI_ICHECK_EQ(num_shards, p_num_shards) << "Runtime number of shards (" << num_shards << ") does not match number of compiled shards (" << p_num_shards << "): " << param->name << " loaded from " << file->data_path; - CHECK_EQ(worker_id, p_worker_id) + TVM_FFI_ICHECK_EQ(worker_id, p_worker_id) << "Runtime worker_id (" << worker_id << ") does not match worker_id of compiled shard (" << p_worker_id << "): " << param->name << " loaded from " << file->data_path; @@ -382,7 +382,7 @@ ffi::Array ShardLoaderObj::LoadAllPresharded() const { .str(); auto it = param_name_to_index_.find(param_name); - CHECK(it != param_name_to_index_.end()) + TVM_FFI_ICHECK(it != param_name_to_index_.end()) << "Parameter " << param_name << " was not found in the parameter set"; int param_id = this->param_name_to_index_.at(param_name); params.push_back(this->LoadDirect(param_id)); @@ -397,36 +397,36 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("runtime.disco.ShardLoaderLoad", [](ObjectRef loader_obj, ffi::Shape weight_index) { const auto* loader = loader_obj.as(); - CHECK(loader != nullptr) - << "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); + TVM_FFI_CHECK(loader != nullptr, TypeError) + << "Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); return loader->Load(IntegerFromShape(weight_index)); }) .def("runtime.disco.ShardLoaderLoadPresharded", [](ObjectRef loader_obj, ffi::Shape weight_index) { const auto* loader = loader_obj.as(); - CHECK(loader != nullptr) - << "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); + TVM_FFI_CHECK(loader != nullptr, TypeError) + << "Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); return loader->LoadPresharded(IntegerFromShape(weight_index)); }) .def("runtime.disco.ShardLoaderLoadAll", [](ObjectRef loader_obj) { const auto* loader = loader_obj.as(); - CHECK(loader != nullptr) - << "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); + TVM_FFI_CHECK(loader != nullptr, TypeError) + << "Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); return loader->LoadAll(); }) .def("runtime.disco.ShardLoaderLoadAllPresharded", [](ObjectRef loader_obj) { const auto* loader = loader_obj.as(); - CHECK(loader != nullptr) - << "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); + TVM_FFI_CHECK(loader != nullptr, TypeError) + << "Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); return loader->LoadAllPresharded(); }) .def("runtime.disco.ShardLoaderLoadParamOnWorker0", [](ObjectRef loader_obj, int param_index) { const auto* loader = loader_obj.as(); - CHECK(loader != nullptr) - << "TypeError: Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); + TVM_FFI_CHECK(loader != nullptr, TypeError) + << "Expected ShardLoaderObj, but gets: " << loader_obj->GetTypeKey(); return loader->LoadParamOnWorker0(param_index); }); } diff --git a/src/runtime/disco/message_queue.h b/src/runtime/disco/message_queue.h index 2f065f131c28..095790c6cb46 100644 --- a/src/runtime/disco/message_queue.h +++ b/src/runtime/disco/message_queue.h @@ -84,12 +84,12 @@ class DiscoStreamMessageQueue : private support::Stream, return true; } - ICHECK_EQ(read_size, sizeof(packet_nbytes)) + TVM_FFI_ICHECK_EQ(read_size, sizeof(packet_nbytes)) << "Stream closed without proper shutdown. Please make sure to explicitly call " "`Session::Shutdown`"; read_buffer_.resize(packet_nbytes); read_size = stream_->Read(read_buffer_.data(), packet_nbytes); - ICHECK_EQ(read_size, packet_nbytes) + TVM_FFI_ICHECK_EQ(read_size, packet_nbytes) << "Stream closed without proper shutdown. Please make sure to explicitly call " "`Session::Shutdown`"; read_offset_ = 0; @@ -102,7 +102,7 @@ class DiscoStreamMessageQueue : private support::Stream, size_t Read(void* data, size_t size) final { std::memcpy(data, read_buffer_.data() + read_offset_, size); read_offset_ += size; - ICHECK_LE(read_offset_, read_buffer_.size()); + TVM_FFI_ICHECK_LE(read_offset_, read_buffer_.size()); return size; } diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index fd4ad06c3fa8..1230cf15f8a7 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -50,7 +50,7 @@ inline ncclRedOp_t AsNCCLRedOp(ReduceKind kind) { case ReduceKind::kAvg: return ncclAvg; } - LOG(FATAL) << "ValueError: Unknown ReduceKind: " << static_cast(kind); + TVM_FFI_THROW(ValueError) << "Unknown ReduceKind: " << static_cast(kind); throw; } @@ -65,24 +65,24 @@ void InitCCL(Session sess, ffi::Shape device_ids) { void InitCCLPerWorker(ffi::Shape device_ids, std::string unique_id_bytes) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); DiscoWorker* worker = DiscoWorker::ThreadLocal(); - ICHECK(worker != nullptr); + TVM_FFI_ICHECK(worker != nullptr); - CHECK_EQ(unique_id_bytes.size(), NCCL_UNIQUE_ID_BYTES) - << "ValueError: The length of unique_id must be " << NCCL_UNIQUE_ID_BYTES << ", but got " + TVM_FFI_CHECK_EQ(unique_id_bytes.size(), NCCL_UNIQUE_ID_BYTES, ValueError) + << "The length of unique_id must be " << NCCL_UNIQUE_ID_BYTES << ", but got " << unique_id_bytes.size() << "."; - CHECK(!ctx->global_comm) << "Cannot initialize CCL, " - << "the previous thread-global comm still exists, " - << "and has not been destructed"; - CHECK(!ctx->group_comm) << "Cannot initialize CCL, " - << "the previous thread-group comm still exists, " - << "and has not been destructed"; - CHECK(!ctx->default_stream) << "Cannot initialize CCL, " - << "the previous thread-global stream still exists, " - << "and has not been destructed"; - CHECK(!ctx->worker) << "Cannot initialize CCL, " - << "the previous thread-global worker still exists, " - << "and has not been destructed"; + TVM_FFI_ICHECK(!ctx->global_comm) << "Cannot initialize CCL, " + << "the previous thread-global comm still exists, " + << "and has not been destructed"; + TVM_FFI_ICHECK(!ctx->group_comm) << "Cannot initialize CCL, " + << "the previous thread-group comm still exists, " + << "and has not been destructed"; + TVM_FFI_ICHECK(!ctx->default_stream) << "Cannot initialize CCL, " + << "the previous thread-global stream still exists, " + << "and has not been destructed"; + TVM_FFI_ICHECK(!ctx->worker) << "Cannot initialize CCL, " + << "the previous thread-global worker still exists, " + << "and has not been destructed"; // Step up local context of NCCL int group_size = worker->num_workers / worker->num_groups; @@ -95,8 +95,8 @@ void InitCCLPerWorker(ffi::Shape device_ids, std::string unique_id_bytes) { if (worker->default_device.device_type == DLDeviceType::kDLCPU) { worker->default_device = device; } else { - ICHECK(worker->default_device.device_type == device.device_type && - worker->default_device.device_id == device.device_id) + TVM_FFI_ICHECK(worker->default_device.device_type == device.device_type && + worker->default_device.device_id == device.device_id) << "The default device of the worker is inconsistent with the device used for CCL. " << "The default device is " << worker->default_device << ", but the device used for CCL is " << device << "."; @@ -123,7 +123,8 @@ void AllReduce(Tensor send, ReduceKind reduce_kind, bool in_group, Tensor recv) deviceStream_t stream = ctx->GetDefaultStream(); DataType dtype = DataType(send->dtype); if (dtype == DataType::Float8E4M3FN() || dtype == DataType::Float8E5M2()) { - LOG(FATAL) << "Float8 data type cannot be allreduced, as nccl does not support this data type."; + TVM_FFI_THROW(InternalError) + << "Float8 data type cannot be allreduced, as nccl does not support this data type."; } NCCL_CALL(ncclAllReduce(send->data, recv->data, numel, /*datatype=*/AsNCCLDataType(dtype), @@ -149,8 +150,8 @@ void BroadcastFromWorker0(ffi::Optional send, bool in_group, Tensor recv const void* send_data = [&]() -> const void* { if (is_sender) { - CHECK(send.defined()); - CHECK(send.value().Shape().Product() == recv.Shape().Product()); + TVM_FFI_ICHECK(send.defined()); + TVM_FFI_ICHECK(send.value().Shape().Product() == recv.Shape().Product()); return send.value()->data; } else { return nullptr; @@ -165,7 +166,7 @@ void BroadcastFromWorker0(ffi::Optional send, bool in_group, Tensor recv } void ScatterFromWorker0(ffi::Optional send, bool in_group, Tensor recv) { - CHECK(recv.defined()) << "ValueError: buffer `recv` must not be None"; + TVM_FFI_CHECK(recv.defined(), ValueError) << "buffer `recv` must not be None"; CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); int worker_id = ctx->worker->worker_id; int num_workers = ctx->worker->num_workers; @@ -174,18 +175,20 @@ void ScatterFromWorker0(ffi::Optional send, bool in_group, Tensor recv) int num_receiver = in_group ? group_size : num_workers; deviceStream_t stream = ctx->GetDefaultStream(); if (is_sender) { - CHECK(send.defined()) << "ValueError: buffer `send` must be provided when worker_id == 0."; + TVM_FFI_CHECK(send.defined(), ValueError) + << "buffer `send` must be provided when worker_id == 0."; Tensor buffer = send.value(); int64_t numel = buffer.Shape().Product(); - CHECK_EQ(numel % num_receiver, 0) << "ValueError: Scattering evenly requires that the number " - "of elements in the buffer to be " - "divisible by the number of workers, but got numel = " - << numel << " and " << num_receiver << " workers."; + TVM_FFI_CHECK_EQ(numel % num_receiver, 0, ValueError) + << "Scattering evenly requires that the number " + "of elements in the buffer to be " + "divisible by the number of workers, but got numel = " + << numel << " and " << num_receiver << " workers."; DataType dtype(buffer->dtype); int64_t numel_per_shard = numel / num_receiver; int64_t bytes_per_shard = numel_per_shard * dtype.bytes(); - CHECK_EQ(numel_per_shard, recv.Shape().Product()) - << "ValueError: The number of elements in buffer `recv` must be the same as each shard " + TVM_FFI_CHECK_EQ(numel_per_shard, recv.Shape().Product(), ValueError) + << "The number of elements in buffer `recv` must be the same as each shard " "of " "buffer `send`. `send.size` is " << numel << ", but `recv.size` is " << recv.Shape().Product() << "."; @@ -212,7 +215,7 @@ void ScatterFromWorker0(ffi::Optional send, bool in_group, Tensor recv) } void GatherToWorker0(Tensor send, bool in_group, ffi::Optional recv) { - CHECK(send.defined()) << "ValueError: buffer `send` must not be None"; + TVM_FFI_CHECK(send.defined(), ValueError) << "buffer `send` must not be None"; CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); int worker_id = ctx->worker->worker_id; int num_workers = ctx->worker->num_workers; @@ -221,18 +224,20 @@ void GatherToWorker0(Tensor send, bool in_group, ffi::Optional recv) { int num_receiver = in_group ? group_size : num_workers; deviceStream_t stream = ctx->GetDefaultStream(); if (is_sender) { - CHECK(recv.defined()) << "ValueError: buffer `recv` must be provided when worker_id == 0."; + TVM_FFI_CHECK(recv.defined(), ValueError) + << "buffer `recv` must be provided when worker_id == 0."; Tensor buffer = recv.value(); int64_t numel = buffer.Shape().Product(); - CHECK_EQ(numel % num_receiver, 0) << "ValueError: Gathering evenly requires that the number " - "of elements in the buffer to be " - "divisible by the number of workers, but got numel = " - << numel << " and " << num_receiver << " workers."; + TVM_FFI_CHECK_EQ(numel % num_receiver, 0, ValueError) + << "Gathering evenly requires that the number " + "of elements in the buffer to be " + "divisible by the number of workers, but got numel = " + << numel << " and " << num_receiver << " workers."; DataType dtype(buffer->dtype); int64_t numel_per_shard = numel / num_receiver; int64_t bytes_per_shard = numel_per_shard * dtype.bytes(); - CHECK_EQ(numel_per_shard, send.Shape().Product()) - << "ValueError: The number of elements in buffer `send` must be the same as each shard " + TVM_FFI_CHECK_EQ(numel_per_shard, send.Shape().Product(), ValueError) + << "The number of elements in buffer `send` must be the same as each shard " "of " "buffer `recv`. `recv.size` is " << numel << ", but `send.size` is " << send.Shape().Product() << "."; @@ -261,8 +266,8 @@ void GatherToWorker0(Tensor send, bool in_group, ffi::Optional recv) { void RecvFromWorker0(Tensor buffer) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); deviceStream_t stream = ctx->GetDefaultStream(); - CHECK_NE(ctx->worker->worker_id, 0) - << "ValueError: Worker 0 is not allowed to call RecvFromWorker0."; + TVM_FFI_CHECK_NE(ctx->worker->worker_id, 0, ValueError) + << "Worker 0 is not allowed to call RecvFromWorker0."; NCCL_CALL(ncclGroupStart()); NCCL_CALL(ncclRecv(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()), 0, ctx->global_comm, stream)); @@ -275,7 +280,7 @@ void SendToNextGroup(Tensor buffer) { int worker_id = ctx->worker->worker_id; int group_size = ctx->worker->num_workers / ctx->worker->num_groups; int receiver_id = worker_id + group_size; - CHECK_LT(receiver_id, ctx->worker->num_workers) + TVM_FFI_ICHECK_LT(receiver_id, ctx->worker->num_workers) << "The current group is already the last group and there is no such a next group."; NCCL_CALL(ncclGroupStart()); NCCL_CALL(ncclSend(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()), @@ -289,7 +294,7 @@ void RecvFromPrevGroup(Tensor buffer) { int worker_id = ctx->worker->worker_id; int group_size = ctx->worker->num_workers / ctx->worker->num_groups; int sender_id = worker_id - group_size; - CHECK_GE(sender_id, 0) + TVM_FFI_ICHECK_GE(sender_id, 0) << "The current group is already the first group and there is no such a previous group."; NCCL_CALL(ncclGroupStart()); NCCL_CALL(ncclRecv(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()), @@ -301,10 +306,10 @@ void SendToWorker(Tensor buffer, int receiver_id) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); deviceStream_t stream = ctx->GetDefaultStream(); int worker_id = ctx->worker->worker_id; - CHECK(receiver_id >= 0 && receiver_id < ctx->worker->num_workers) + TVM_FFI_ICHECK(receiver_id >= 0 && receiver_id < ctx->worker->num_workers) << "Invalid receiver id " << receiver_id << ". The world size is " << ctx->worker->num_workers; - CHECK_NE(worker_id, receiver_id) << "Cannot send to worker itself."; + TVM_FFI_ICHECK_NE(worker_id, receiver_id) << "Cannot send to worker itself."; NCCL_CALL(ncclSend(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()), receiver_id, ctx->global_comm, stream)); } @@ -313,16 +318,16 @@ void RecvFromWorker(Tensor buffer, int sender_id) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); deviceStream_t stream = ctx->GetDefaultStream(); int worker_id = ctx->worker->worker_id; - CHECK(sender_id >= 0 && sender_id < ctx->worker->num_workers) + TVM_FFI_ICHECK(sender_id >= 0 && sender_id < ctx->worker->num_workers) << "Invalid sender id " << sender_id << ". The world size is " << ctx->worker->num_workers; - CHECK_NE(worker_id, sender_id) << "Cannot receive from the worker itself."; + TVM_FFI_ICHECK_NE(worker_id, sender_id) << "Cannot receive from the worker itself."; NCCL_CALL(ncclRecv(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()), sender_id, ctx->global_comm, stream)); } void SyncWorker() { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); - ICHECK(ctx->worker != nullptr); + TVM_FFI_ICHECK(ctx->worker != nullptr); deviceStream_t stream = ctx->GetDefaultStream(); StreamSynchronize(stream); } @@ -335,7 +340,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("runtime.disco." TVM_DISCO_CCL_NAME ".init_ccl_per_worker", InitCCLPerWorker) .def("runtime.disco." TVM_DISCO_CCL_NAME ".allreduce", [](Tensor send, int kind, bool in_group, Tensor recv) { - CHECK(0 <= kind && kind <= 4) << "ValueError: Unknown ReduceKind: " << kind; + TVM_FFI_CHECK(0 <= kind && kind <= 4, ValueError) << "Unknown ReduceKind: " << kind; nccl::AllReduce(send, static_cast(kind), in_group, recv); }) .def("runtime.disco." TVM_DISCO_CCL_NAME ".allgather", @@ -352,8 +357,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("runtime.disco." TVM_DISCO_CCL_NAME ".test_send_to_next_group_recv_from_prev_group", [](Tensor buffer) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); - CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; - CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2."; + TVM_FFI_ICHECK_EQ(ctx->worker->num_workers, 4) + << "The test requires the world size to be 4."; + TVM_FFI_ICHECK_EQ(ctx->worker->num_groups, 2) + << "The test requires the group size to be 2."; int group_size = ctx->worker->num_workers / ctx->worker->num_groups; int group_id = ctx->worker->worker_id / group_size; if (group_id == 0) { @@ -364,8 +371,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { }) .def("runtime.disco." TVM_DISCO_CCL_NAME ".test_worker2_sends_to_worker0", [](Tensor buffer) { CCLThreadLocalContext* ctx = CCLThreadLocalContext::Get(); - CHECK_EQ(ctx->worker->num_workers, 4) << "The test requires the world size to be 4."; - CHECK_EQ(ctx->worker->num_groups, 2) << "The test requires the group size to be 2."; + TVM_FFI_ICHECK_EQ(ctx->worker->num_workers, 4) + << "The test requires the world size to be 4."; + TVM_FFI_ICHECK_EQ(ctx->worker->num_groups, 2) + << "The test requires the group size to be 2."; if (ctx->worker->worker_id == 2) { tvm::runtime::nccl::SendToWorker(buffer, 0); } else if (ctx->worker->worker_id == 0) { diff --git a/src/runtime/disco/nccl/nccl_context.h b/src/runtime/disco/nccl/nccl_context.h index e24687d8675f..e18137f42d18 100644 --- a/src/runtime/disco/nccl/nccl_context.h +++ b/src/runtime/disco/nccl/nccl_context.h @@ -47,12 +47,12 @@ namespace tvm { namespace runtime { namespace nccl { -#define NCCL_CALL(cmd) \ - do { \ - auto r = (cmd); \ - if (r != ncclSuccess) { \ - LOG(FATAL) << TVM_DISCO_CCL_NAME "Errror: " << ncclGetErrorString(r); \ - } \ +#define NCCL_CALL(cmd) \ + do { \ + auto r = (cmd); \ + if (r != ncclSuccess) { \ + TVM_FFI_THROW(InternalError) << TVM_DISCO_CCL_NAME "Errror: " << ncclGetErrorString(r); \ + } \ } while (0) #if TVM_NCCL_RCCL_SWITCH == 0 @@ -116,7 +116,7 @@ inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) { if (dtype == DataType::BFloat(16)) { return ncclBfloat16; } - LOG(FATAL) << "ValueError: Unsupported data type " << dtype; + TVM_FFI_THROW(ValueError) << "Unsupported data type " << dtype; throw; } @@ -150,7 +150,7 @@ struct CCLThreadLocalContext { deviceStream_t GetDefaultStream() { const auto func = tvm::ffi::Function::GetGlobal("runtime.get_" TVM_DISCO_DEVICE_NAME "_stream"); - ICHECK(func.has_value()); + TVM_FFI_ICHECK(func.has_value()); deviceStream_t stream = static_cast((*func)().cast()); return stream == nullptr ? default_stream : stream; } diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc index c13cd9e60e9d..36120e912b8c 100644 --- a/src/runtime/disco/process_session.cc +++ b/src/runtime/disco/process_session.cc @@ -71,8 +71,9 @@ class ProcessSessionObj final : public BcastSessionObj { write_fds.reserve(num_workers - 1); for (int i = 1; i < num_workers; ++i) { ffi::Shape fds = process_pool(i).cast(); - CHECK_EQ(fds.size(), 2) << "ValueError: process_pool(" << i << ") should return a tuple of " - << "size 2, but got a tuple of size " << fds.size() << "."; + TVM_FFI_CHECK_EQ(fds.size(), 2, ValueError) + << "process_pool(" << i << ") should return a tuple of " + << "size 2, but got a tuple of size " << fds.size() << "."; read_fds.push_back(fds[0]); write_fds.push_back(fds[1]); } @@ -106,8 +107,9 @@ class ProcessSessionObj final : public BcastSessionObj { workers_[worker_id - 1]->Send(ffi::PackedArgs(packed_args, 3)); } ffi::PackedArgs args = this->RecvReplyPacked(worker_id); - ICHECK_EQ(args.size(), 2); - ICHECK(static_cast(args[0].cast()) == DiscoAction::kDebugGetFromRemote); + TVM_FFI_ICHECK_EQ(args.size(), 2); + TVM_FFI_ICHECK(static_cast(args[0].cast()) == + DiscoAction::kDebugGetFromRemote); ffi::Any result; result = args[1]; return result; @@ -132,8 +134,8 @@ class ProcessSessionObj final : public BcastSessionObj { } ffi::Any result; ffi::PackedArgs args = this->RecvReplyPacked(worker_id); - ICHECK_EQ(args.size(), 1); - ICHECK(static_cast(args[0].cast()) == DiscoAction::kDebugSetRegister); + TVM_FFI_ICHECK_EQ(args.size(), 1); + TVM_FFI_ICHECK(static_cast(args[0].cast()) == DiscoAction::kDebugSetRegister); } void BroadcastPacked(const ffi::PackedArgs& args) final { @@ -173,11 +175,11 @@ class ProcessSessionObj final : public BcastSessionObj { Session Session::ProcessSession(int num_workers, int num_group, ffi::String process_pool_creator, ffi::String entrypoint) { - CHECK_EQ(num_workers % num_group, 0) + TVM_FFI_ICHECK_EQ(num_workers % num_group, 0) << "The number of workers should be divisible by the number of worker group."; const auto pf = tvm::ffi::Function::GetGlobal(process_pool_creator); - CHECK(pf) << "ValueError: Cannot find function " << process_pool_creator - << " in the registry. Please check if it is registered."; + TVM_FFI_CHECK(pf, ValueError) << "Cannot find function " << process_pool_creator + << " in the registry. Please check if it is registered."; auto process_pool = (*pf)(num_workers, num_group, entrypoint).cast(); auto n = ffi::make_object(num_workers, num_group, process_pool); return Session(n); @@ -185,7 +187,7 @@ Session Session::ProcessSession(int num_workers, int num_group, ffi::String proc void WorkerProcess(int worker_id, int num_workers, int num_group, int64_t read_fd, int64_t write_fd) { - CHECK_EQ(num_workers % num_group, 0) + TVM_FFI_ICHECK_EQ(num_workers % num_group, 0) << "The number of workers should be divisible by the number of worker group."; DiscoProcessChannel channel(read_fd, write_fd); DiscoWorker worker(worker_id, num_workers, num_group, nullptr, &channel); diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h index e84b9df04dbb..f3c1fcf35a5f 100644 --- a/src/runtime/disco/protocol.h +++ b/src/runtime/disco/protocol.h @@ -71,7 +71,7 @@ struct DiscoProtocol { /*! \brief Callback method when an error occurs in (de)-serialization. Used by RPCReference. */ void ThrowError(RPCServerStatus status) { - LOG(FATAL) << "InternalError: Unexpected error in RPC: " << RPCServerStatusToString(status); + TVM_FFI_THROW(InternalError) << "Unexpected error in RPC: " << RPCServerStatusToString(status); } /*!\ brief Arena used by RPCReference to allocate POD memory */ @@ -137,9 +137,9 @@ inline uint64_t DiscoProtocol::GetFFIAnyProtocolBytes(const TVMFFI } else if (const auto opt_debug_obj = any_view_ptr->as()) { return sizeof(uint32_t) + (*opt_debug_obj).GetFFIAnyProtocolBytes(); } else { - LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " - << any_view_ptr->GetTypeKey() << " (type_index = " << any_view_ptr->type_index() - << ")"; + TVM_FFI_THROW(ValueError) << "Object type is not supported in Disco calling convention: " + << any_view_ptr->GetTypeKey() + << " (type_index = " << any_view_ptr->type_index() << ")"; return 0; } } @@ -169,9 +169,9 @@ inline void DiscoProtocol::WriteFFIAny(const TVMFFIAny* value) { self->template Write(str.size()); self->template WriteArray(str.data(), str.size()); } else { - LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " - << any_view_ptr->GetTypeKey() << " (type_index = " << any_view_ptr->type_index() - << ")"; + TVM_FFI_THROW(ValueError) << "Object type is not supported in Disco calling convention: " + << any_view_ptr->GetTypeKey() + << " (type_index = " << any_view_ptr->type_index() << ")"; } } @@ -211,8 +211,9 @@ inline void DiscoProtocol::ReadFFIAny(TVMFFIAny* out) { self->template ReadArray(data.data(), size); result = DiscoDebugObject::LoadFromStr(std::move(data))->data.cast(); } else { - LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " - << Object::TypeIndex2Key(type_index) << " (type_index = " << type_index << ")"; + TVM_FFI_THROW(ValueError) << "Object type is not supported in Disco calling convention: " + << Object::TypeIndex2Key(type_index) + << " (type_index = " << type_index << ")"; } *reinterpret_cast(out) = result; any_arena_.push_back(result); @@ -233,36 +234,36 @@ inline std::string DiscoDebugObject::SaveToStr() const { } else if (auto opt_obj = this->data.as()) { ObjectRef obj = opt_obj.value(); const auto f = tvm::ffi::Function::GetGlobal("node.SaveJSON"); - CHECK(f.has_value()) << "ValueError: Cannot serialize object in non-debugging mode: " - << obj->GetTypeKey(); + TVM_FFI_CHECK(f.has_value(), ValueError) + << "Cannot serialize object in non-debugging mode: " << obj->GetTypeKey(); std::string result = (*f)(obj).cast(); result.push_back('0'); return result; } - LOG(FATAL) << "ValueError: Cannot serialize the following type code in non-debugging mode: " - << this->data.GetTypeKey(); + TVM_FFI_THROW(ValueError) << "Cannot serialize the following type code in non-debugging mode: " + << this->data.GetTypeKey(); return ""; } inline ObjectPtr DiscoDebugObject::LoadFromStr(std::string json_str) { - ICHECK(!json_str.empty()); + TVM_FFI_ICHECK(!json_str.empty()); char control_bit = json_str.back(); json_str.pop_back(); ObjectPtr result = ffi::make_object(); if (control_bit == '0') { const auto f = tvm::ffi::Function::GetGlobal("node.LoadJSON"); - CHECK(f.has_value()) << "ValueError: Cannot deserialize object in non-debugging mode"; + TVM_FFI_CHECK(f.has_value(), ValueError) << "Cannot deserialize object in non-debugging mode"; result->data = (*f)(json_str); } else if (control_bit == '1') { support::BytesInStream mstrm(json_str); support::Base64InStream b64strm(&mstrm); b64strm.InitPosition(); runtime::Tensor array; - ICHECK(array.Load(&b64strm)); + TVM_FFI_ICHECK(array.Load(&b64strm)); result->data = std::move(array); } else { - LOG(FATAL) << "ValueError: Unsupported control bit: " << control_bit - << ". Full string: " << json_str; + TVM_FFI_THROW(ValueError) << "Unsupported control bit: " << control_bit + << ". Full string: " << json_str; } return result; } diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc index 398c5885e830..7b92d3b93f9d 100644 --- a/src/runtime/disco/threaded_session.cc +++ b/src/runtime/disco/threaded_session.cc @@ -93,7 +93,7 @@ class DiscoThreadedMessageQueue : private support::Stream, size_t Read(void* data, size_t size) final { std::memcpy(data, read_buffer_.data() + read_offset_, size); read_offset_ += size; - ICHECK_LE(read_offset_, read_buffer_.size()); + TVM_FFI_ICHECK_LE(read_offset_, read_buffer_.size()); return size; } @@ -188,7 +188,7 @@ class ThreadedSessionObj final : public BcastSessionObj { }; Session Session::ThreadedSession(int num_workers, int num_group) { - CHECK_EQ(num_workers % num_group, 0) + TVM_FFI_ICHECK_EQ(num_workers % num_group, 0) << "The number of workers should be divisible by the number of worker group."; ObjectPtr n = ffi::make_object(num_workers, num_group); return Session(std::move(n)); diff --git a/src/runtime/disco/utils.h b/src/runtime/disco/utils.h index fb68335d8c5e..24f0a3496e5e 100644 --- a/src/runtime/disco/utils.h +++ b/src/runtime/disco/utils.h @@ -37,7 +37,8 @@ inline Device UseDefaultDeviceIfNone(ffi::Optional device) { * integers. A common workaround is to use a 1-d shape tuple as an integer. */ inline int64_t IntegerFromShape(const ffi::Shape& shape) { - CHECK_EQ(shape.size(), 1) << "ValueError: shape tuple must be 1-d to be converted to integer."; + TVM_FFI_CHECK_EQ(shape.size(), 1, ValueError) + << "shape tuple must be 1-d to be converted to integer."; return shape[0]; } diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index aa042dcf4bb2..4726b09dd2a9 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -149,7 +149,7 @@ std::string GetMetaFilePath(const std::string& file_name) { void LoadBinaryFromFile(const std::string& file_name, std::string* data) { std::ifstream fs(file_name, std::ios::in | std::ios::binary); - ICHECK(!fs.fail()) << "Cannot open " << file_name; + TVM_FFI_ICHECK(!fs.fail()) << "Cannot open " << file_name; // get its size: fs.seekg(0, std::ios::end); size_t size = static_cast(fs.tellg()); @@ -160,7 +160,7 @@ void LoadBinaryFromFile(const std::string& file_name, std::string* data) { void SaveBinaryToFile(const std::string& file_name, const std::string& data) { std::ofstream fs(file_name, std::ios::out | std::ios::binary); - ICHECK(!fs.fail()) << "Cannot open " << file_name; + TVM_FFI_ICHECK(!fs.fail()) << "Cannot open " << file_name; fs.write(&data[0], data.length()); } @@ -175,7 +175,7 @@ void SaveMetaDataToFile(const std::string& file_name, } root.Set("func_info", std::move(func_info)); std::ofstream fs(file_name.c_str()); - ICHECK(!fs.fail()) << "Cannot open file " << file_name; + TVM_FFI_ICHECK(!fs.fail()) << "Cannot open file " << file_name; fs << std::string(json::Stringify(root, 2)); fs.close(); } @@ -183,7 +183,7 @@ void SaveMetaDataToFile(const std::string& file_name, void LoadMetaDataFromFile(const std::string& file_name, ffi::Map* fmap) { namespace json = ::tvm::ffi::json; std::ifstream fs(file_name.c_str()); - ICHECK(!fs.fail()) << "Cannot open file " << file_name; + TVM_FFI_ICHECK(!fs.fail()) << "Cannot open file " << file_name; std::string content((std::istreambuf_iterator(fs)), std::istreambuf_iterator()); fs.close(); auto root = json::Parse(content).cast(); @@ -203,19 +203,19 @@ void RemoveFile(const std::string& file_name) { void CopyFile(const std::string& src_file_name, const std::string& dest_file_name) { std::ifstream src(src_file_name, std::ios::binary); - ICHECK(src) << "Unable to open source file '" << src_file_name << "'"; + TVM_FFI_ICHECK(src) << "Unable to open source file '" << src_file_name << "'"; std::ofstream dest(dest_file_name, std::ios::binary | std::ios::trunc); - ICHECK(dest) << "Unable to destination source file '" << src_file_name << "'"; + TVM_FFI_ICHECK(dest) << "Unable to destination source file '" << src_file_name << "'"; dest << src.rdbuf(); src.close(); dest.close(); - ICHECK(dest) << "File-copy operation failed." - << " src='" << src_file_name << "'" - << " dest='" << dest_file_name << "'"; + TVM_FFI_ICHECK(dest) << "File-copy operation failed." + << " src='" << src_file_name << "'" + << " dest='" << dest_file_name << "'"; } ffi::Map LoadParams(const std::string& param_blob) { @@ -225,16 +225,16 @@ ffi::Map LoadParams(const std::string& param_blob) { ffi::Map LoadParams(support::Stream* strm) { ffi::Map params; uint64_t header, reserved; - ICHECK(strm->Read(&header)) << "Invalid parameters file format"; - ICHECK(header == kTVMTensorListMagic) << "Invalid parameters file format"; - ICHECK(strm->Read(&reserved)) << "Invalid parameters file format"; + TVM_FFI_ICHECK(strm->Read(&header)) << "Invalid parameters file format"; + TVM_FFI_ICHECK(header == kTVMTensorListMagic) << "Invalid parameters file format"; + TVM_FFI_ICHECK(strm->Read(&reserved)) << "Invalid parameters file format"; std::vector names; - ICHECK(strm->Read(&names)) << "Invalid parameters file format"; + TVM_FFI_ICHECK(strm->Read(&names)) << "Invalid parameters file format"; uint64_t sz; strm->Read(&sz); size_t size = static_cast(sz); - ICHECK(size == names.size()) << "Invalid parameters file format"; + TVM_FFI_ICHECK(size == names.size()) << "Invalid parameters file format"; for (size_t i = 0; i < size; ++i) { // The data_entry is allocated on device, Tensor.load always load the array into CPU. Tensor temp; diff --git a/src/runtime/file_utils.h b/src/runtime/file_utils.h index 4b300c39cfb8..b0a8c9dec902 100644 --- a/src/runtime/file_utils.h +++ b/src/runtime/file_utils.h @@ -139,24 +139,24 @@ struct SimpleBinaryFileStream : public support::Stream { SimpleBinaryFileStream(const std::string& path, std::string mode) { const char* fname = path.c_str(); - CHECK(mode == "wb" || mode == "rb") << "Only allowed modes are 'wb' and 'rb'"; + TVM_FFI_ICHECK(mode == "wb" || mode == "rb") << "Only allowed modes are 'wb' and 'rb'"; read_ = mode == "rb"; fp_ = std::fopen(fname, mode.c_str()); - CHECK(fp_ != nullptr) << "Unable to open file " << path; + TVM_FFI_ICHECK(fp_ != nullptr) << "Unable to open file " << path; } virtual ~SimpleBinaryFileStream(void) { this->Close(); } virtual size_t Read(void* ptr, size_t size) { - CHECK(read_) << "File opened in write-mode, cannot read."; - CHECK(fp_ != nullptr) << "File is closed"; + TVM_FFI_ICHECK(read_) << "File opened in write-mode, cannot read."; + TVM_FFI_ICHECK(fp_ != nullptr) << "File is closed"; return std::fread(ptr, 1, size, fp_); } virtual size_t Write(const void* ptr, size_t size) { - CHECK(!read_) << "File opened in read-mode, cannot write."; - CHECK(fp_ != nullptr) << "File is closed"; + TVM_FFI_ICHECK(!read_) << "File opened in read-mode, cannot write."; + TVM_FFI_ICHECK(fp_ != nullptr) << "File is closed"; size_t nwrite = std::fwrite(ptr, 1, size, fp_); int err = std::ferror(fp_); - CHECK_EQ(err, 0) << "SimpleBinaryFileStream.Write incomplete: " << std::strerror(err); + TVM_FFI_ICHECK_EQ(err, 0) << "SimpleBinaryFileStream.Write incomplete: " << std::strerror(err); return nwrite; } inline void Close(void) { diff --git a/src/runtime/hexagon/hexagon_buffer.cc b/src/runtime/hexagon/hexagon_buffer.cc index c6dd9421fe63..aa1453dbe500 100644 --- a/src/runtime/hexagon/hexagon_buffer.cc +++ b/src/runtime/hexagon/hexagon_buffer.cc @@ -49,7 +49,7 @@ struct Allocation { struct DDRAllocation : public Allocation { DDRAllocation(size_t nbytes, size_t alignment) : Allocation(nbytes, alignment) { int ret = posix_memalign(&data_, alignment, nbytes); - CHECK_EQ(ret, 0); + TVM_FFI_ICHECK_EQ(ret, 0); // The heap used by malloc on Hexagon is always mapped as cacheable. The heap manager may not // perform cache invalidation on a prior memory free. So, a subsequent memory allocation request @@ -68,7 +68,7 @@ struct VTCMAllocation : public Allocation { VTCMAllocation(size_t nbytes, size_t alignment) : Allocation(nbytes, alignment) { // For simplicity, the current VTCM dynamic pool supports the following alignments: less than // or equal to 128 (0x80), and 2k (0x800) - CHECK((alignment <= 0x80) || (alignment == 0x800)) + TVM_FFI_ICHECK((alignment <= 0x80) || (alignment == 0x800)) << "VTCMAllocation called for invalid alignment " << alignment; if (alignment == 0x800) { @@ -119,7 +119,7 @@ HexagonBuffer::HexagonBuffer(size_t nbytes, size_t alignment, ffi::Optional(nbytes, alignment); } - CHECK(alloca != nullptr); + TVM_FFI_ICHECK(alloca != nullptr); allocations_.push_back(alloca->data_); managed_allocations_.push_back(std::move(alloca)); } @@ -138,7 +138,7 @@ HexagonBuffer::HexagonBuffer(size_t nallocs, size_t nbytes, size_t alignment, } else if (GetStorageScope() == StorageScope::kVTCM) { alloca = Allocator(nbytes_monolithic, alignment); } - CHECK(alloca) << "could not create allocation"; + TVM_FFI_ICHECK(alloca) << "could not create allocation"; for (size_t i = 0; i < nallocs; ++i) { void* alloc_offset = static_cast(alloca->data_) + i * nbytes_aligned; @@ -151,16 +151,17 @@ HexagonBuffer::HexagonBuffer(size_t nallocs, size_t nbytes, size_t alignment, HexagonBuffer::~HexagonBuffer() { managed_allocations_.clear(); } void* HexagonBuffer::GetPointer() { - ICHECK(allocations_.size()) + TVM_FFI_ICHECK(allocations_.size()) << "Internal failure, allocations_ should be set in HexagonBuffer constructor"; if (ndim_ == 1) { - ICHECK_EQ(allocations_.size(), 1); + TVM_FFI_ICHECK_EQ(allocations_.size(), 1); return allocations_[0]; } else if (ndim_ == 2) { return allocations_.data(); } else { - LOG(FATAL) << "HexagonBuffer should be either 1-d or 2-d, not " << ndim_ << "-d"; + TVM_FFI_THROW(InternalError) << "HexagonBuffer should be either 1-d or 2-d, not " << ndim_ + << "-d"; } } @@ -176,14 +177,14 @@ void HexagonBuffer::SetStorageScope(ffi::Optional scope) { } else if (s == "global.vtcm") { storage_scope_ = StorageScope::kVTCM; } else { - CHECK(false) << "Encountered unknown HexagonBuffer storage scope: " << std::string(s); + TVM_FFI_ICHECK(false) << "Encountered unknown HexagonBuffer storage scope: " << std::string(s); } } std::vector BufferSet::MemoryCopies(const BufferSet& dest, const BufferSet& src, size_t bytes_to_copy) { - CHECK_LE(bytes_to_copy, src.TotalBytes()); - CHECK_LE(bytes_to_copy, dest.TotalBytes()); + TVM_FFI_ICHECK_LE(bytes_to_copy, src.TotalBytes()); + TVM_FFI_ICHECK_LE(bytes_to_copy, dest.TotalBytes()); auto pointer_to = [](const BufferSet& buf, size_t region_i, size_t byte_i) -> void* { void* region = buf.buffers[region_i]; diff --git a/src/runtime/hexagon/hexagon_buffer_manager.h b/src/runtime/hexagon/hexagon_buffer_manager.h index 3c43c25b5863..3f26b2554b30 100644 --- a/src/runtime/hexagon/hexagon_buffer_manager.h +++ b/src/runtime/hexagon/hexagon_buffer_manager.h @@ -48,9 +48,9 @@ class HexagonBufferManager { void FreeHexagonBuffer(void* ptr) { std::lock_guard lock(map_mutex_); auto it = hexagon_buffer_map_.find(ptr); - CHECK(it != hexagon_buffer_map_.end()) + TVM_FFI_ICHECK(it != hexagon_buffer_map_.end()) << "Attempt made to free unknown or already freed allocation"; - CHECK(it->second != nullptr); + TVM_FFI_ICHECK(it->second != nullptr); hexagon_buffer_map_.erase(it); } /*! diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index 662a2209592d..0d1c432571c6 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -52,8 +52,9 @@ void HexagonDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) { // DataSpace: static allocations for Hexagon void* HexagonDeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, ffi::Optional mem_scope) { - CHECK(shape || ndim == 0) << "shape array is null for a non-scalar tensor, ndim = " << ndim; - CHECK(IsValidDevice(dev)) << "dev.device_type: " << dev.device_type; + TVM_FFI_ICHECK(shape || ndim == 0) + << "shape array is null for a non-scalar tensor, ndim = " << ndim; + TVM_FFI_ICHECK(IsValidDevice(dev)) << "dev.device_type: " << dev.device_type; // IMPORTANT NOTE! // Hexagon treats "global" memory scope VERY DIFFERENTLY from all the others. @@ -79,18 +80,19 @@ void* HexagonDeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shap // NOTE: This check should be superfluous, but it's probably a good idea to leave it in // until the AoT executor's multi-device dispatch code is mature. --cconvey 2022-08-26 - CHECK(dev.device_type == kDLHexagon) + TVM_FFI_ICHECK(dev.device_type == kDLHexagon) << "dev.device_type: " << dev.device_type << " DeviceName(" << dev.device_type << "): " << DLDeviceType2Str(dev.device_type) << ""; - CHECK(ndim >= 0 && ndim <= 2) + TVM_FFI_ICHECK(ndim >= 0 && ndim <= 2) << "Hexagon Device API supports only 1d and 2d allocations, but received ndim = " << ndim; const size_t typesize = (dtype.bits / 8) * dtype.lanes; - CHECK(runtime_hexbuffs) << "Attempted to allocate Hexagon data with " - << "HexagonDeviceAPI::AllocDataSpace before initializing resources. " - << "Please call HexagonDeviceAPI::AcquireResources"; + TVM_FFI_ICHECK(runtime_hexbuffs) + << "Attempted to allocate Hexagon data with " + << "HexagonDeviceAPI::AllocDataSpace before initializing resources. " + << "Please call HexagonDeviceAPI::AcquireResources"; if (ndim == 0) { // Allocate storage for a single scalar value. @@ -112,21 +114,22 @@ void* HexagonDeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shap void* HexagonDeviceAPI::AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) { - CHECK(nbytes) << "number of bytes is zero"; - CHECK(alignment) << "alignment is zero"; - CHECK(IsValidDevice(dev)) << "dev.device_type: " << dev.device_type; + TVM_FFI_ICHECK(nbytes) << "number of bytes is zero"; + TVM_FFI_ICHECK(alignment) << "alignment is zero"; + TVM_FFI_ICHECK(IsValidDevice(dev)) << "dev.device_type: " << dev.device_type; if (alignment < kHexagonAllocAlignment) { alignment = kHexagonAllocAlignment; } - CHECK(runtime_hexbuffs) << "Attempted to allocate Hexagon data with " - << "HexagonDeviceAPI::AllocDataSpace before initializing resources. " - << "Please call HexagonDeviceAPI::AcquireResources"; + TVM_FFI_ICHECK(runtime_hexbuffs) + << "Attempted to allocate Hexagon data with " + << "HexagonDeviceAPI::AllocDataSpace before initializing resources. " + << "Please call HexagonDeviceAPI::AcquireResources"; return runtime_hexbuffs->AllocateHexagonBuffer(nbytes, alignment, ffi::String("global")); } void HexagonDeviceAPI::FreeDataSpace(Device dev, void* ptr) { - CHECK(ptr) << "buffer pointer is null"; - CHECK(IsValidDevice(dev)) << "dev.device_type: " << dev.device_type; + TVM_FFI_ICHECK(ptr) << "buffer pointer is null"; + TVM_FFI_ICHECK(IsValidDevice(dev)) << "dev.device_type: " << dev.device_type; if (runtime_hexbuffs) { runtime_hexbuffs->FreeHexagonBuffer(ptr); } else { @@ -148,27 +151,28 @@ static HexagonWorkspacePool* HexagonWorkspacePoolThreadLocal() { } void* HexagonDeviceAPI::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { - CHECK(IsValidDevice(dev)) << "dev.device_type: " << dev.device_type; + TVM_FFI_ICHECK(IsValidDevice(dev)) << "dev.device_type: " << dev.device_type; return HexagonWorkspacePoolThreadLocal()->AllocWorkspace(dev, size); } void HexagonDeviceAPI::FreeWorkspace(Device dev, void* data) { - CHECK(IsValidDevice(dev)) << "dev.device_type: " << dev.device_type; - CHECK(runtime_hexbuffs) << "Attempted to free Hexagon workspace with " - << "HexagonDeviceAPI::FreeWorkspace outside of a session. " - << "Please call HexagonDeviceAPI::AcquireResources"; - CHECK(runtime_hexbuffs->FindHexagonBuffer(data) != nullptr) + TVM_FFI_ICHECK(IsValidDevice(dev)) << "dev.device_type: " << dev.device_type; + TVM_FFI_ICHECK(runtime_hexbuffs) << "Attempted to free Hexagon workspace with " + << "HexagonDeviceAPI::FreeWorkspace outside of a session. " + << "Please call HexagonDeviceAPI::AcquireResources"; + TVM_FFI_ICHECK(runtime_hexbuffs->FindHexagonBuffer(data) != nullptr) << "Attempt made to free unknown or already freed workspace allocation"; HexagonWorkspacePoolThreadLocal()->FreeWorkspace(dev, data); } void HexagonDeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { - CHECK_EQ(from->byte_offset, 0); - CHECK_EQ(to->byte_offset, 0); - CHECK_EQ(GetDataSize(*from), GetDataSize(*to)); - CHECK(runtime_hexbuffs) << "Attempted to copy Hexagon data with " - << "HexagonDeviceAPI::CopyDataFromTo before initializing resources. " - << "Please call HexagonDeviceAPI::AcquireResources"; + TVM_FFI_ICHECK_EQ(from->byte_offset, 0); + TVM_FFI_ICHECK_EQ(to->byte_offset, 0); + TVM_FFI_ICHECK_EQ(GetDataSize(*from), GetDataSize(*to)); + TVM_FFI_ICHECK(runtime_hexbuffs) + << "Attempted to copy Hexagon data with " + << "HexagonDeviceAPI::CopyDataFromTo before initializing resources. " + << "Please call HexagonDeviceAPI::AcquireResources"; auto lookup_hexagon_buffer = [this](void* ptr) -> HexagonBuffer* { return runtime_hexbuffs->FindHexagonBuffer(ptr); @@ -184,8 +188,9 @@ void HexagonDeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHan } else if (hex_from_buf) { hex_from_buf->CopyTo(to->data, GetDataSize(*to)); } else { - CHECK(false) << "CopyDataFromTo requested between src and dst which are not managed by the " - "hexagon device api."; + TVM_FFI_ICHECK(false) + << "CopyDataFromTo requested between src and dst which are not managed by the " + "hexagon device api."; } } @@ -203,7 +208,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { auto dst = args[0].cast(); auto src = args[1].cast(); int size = args[2].cast(); - ICHECK(size > 0); + TVM_FFI_ICHECK(size > 0); bool bypass_cache = args[3].cast(); int ret = DMA_RETRY; @@ -211,7 +216,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { ret = HexagonDeviceAPI::Global()->UserDMA()->Copy( SYNC_DMA_QUEUE, dst->data, src->data, size, bypass_cache); } while (ret == DMA_RETRY); - CHECK(ret == DMA_SUCCESS); + TVM_FFI_ICHECK(ret == DMA_SUCCESS); HexagonDeviceAPI::Global()->UserDMA()->Wait(SYNC_DMA_QUEUE, 0); *rv = static_cast(0); @@ -222,7 +227,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { void* dst = args[1].cast(); void* src = args[2].cast(); uint32_t size = args[3].cast(); - ICHECK(size > 0); + TVM_FFI_ICHECK(size > 0); bool bypass_cache = args[4].cast(); int ret = DMA_RETRY; @@ -230,14 +235,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { ret = HexagonDeviceAPI::Global()->UserDMA()->Copy(queue_id, dst, src, size, bypass_cache); } while (ret == DMA_RETRY); - CHECK(ret == DMA_SUCCESS); + TVM_FFI_ICHECK(ret == DMA_SUCCESS); *rv = static_cast(ret); }) .def_packed("device_api.hexagon.dma_wait", [](ffi::PackedArgs args, ffi::Any* rv) { uint32_t queue_id = args[0].cast(); int inflight = args[1].cast(); - ICHECK(inflight >= 0); + TVM_FFI_ICHECK(inflight >= 0); HexagonDeviceAPI::Global()->UserDMA()->Wait(queue_id, inflight); *rv = static_cast(0); }) @@ -260,10 +265,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { int32_t dtype_code_hint = args[2].cast(); int32_t dtype_bits_hint = args[3].cast(); auto scope = args[4].cast(); - CHECK(scope.find("global.vtcm") != std::string::npos); + TVM_FFI_ICHECK(scope.find("global.vtcm") != std::string::npos); int64_t ndim = args[5].cast(); - CHECK((ndim == 1 || ndim == 2) && - "Hexagon Device API supports only 1d and 2d allocations"); + TVM_FFI_ICHECK((ndim == 1 || ndim == 2) && + "Hexagon Device API supports only 1d and 2d allocations"); int64_t* shape = static_cast(args[6].cast()); Device dev; @@ -283,7 +288,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { int32_t device_type = args[0].cast(); int32_t device_id = args[1].cast(); auto scope = args[2].cast(); - CHECK(scope.find("global.vtcm") != std::string::npos); + TVM_FFI_ICHECK(scope.find("global.vtcm") != std::string::npos); void* ptr = args[3].cast(); Device dev; diff --git a/src/runtime/hexagon/hexagon_device_api.h b/src/runtime/hexagon/hexagon_device_api.h index 76439ef531ae..c37614f5374f 100644 --- a/src/runtime/hexagon/hexagon_device_api.h +++ b/src/runtime/hexagon/hexagon_device_api.h @@ -56,38 +56,39 @@ class HexagonDeviceAPI final : public DeviceAPI { //! \brief Ensures resource managers are in a good state for the runtime void AcquireResources() { - CHECK_EQ(runtime_power_manager, nullptr); + TVM_FFI_ICHECK_EQ(runtime_power_manager, nullptr); runtime_power_manager = std::make_unique(); - CHECK_EQ(runtime_vtcm, nullptr); + TVM_FFI_ICHECK_EQ(runtime_vtcm, nullptr); runtime_vtcm = std::make_unique(); - CHECK_EQ(runtime_hexbuffs, nullptr); + TVM_FFI_ICHECK_EQ(runtime_hexbuffs, nullptr); runtime_hexbuffs = std::make_unique(); - CHECK_EQ(runtime_threads, nullptr); + TVM_FFI_ICHECK_EQ(runtime_threads, nullptr); runtime_threads = std::make_unique(threads, stack_size, pipe_size, hw_resources); - CHECK_EQ(runtime_dma, nullptr); + TVM_FFI_ICHECK_EQ(runtime_dma, nullptr); runtime_dma = std::make_unique(); } //! \brief Ensures all runtime resources are freed void ReleaseResources() { - CHECK(runtime_dma) << "runtime_dma was not created in AcquireResources"; + TVM_FFI_ICHECK(runtime_dma) << "runtime_dma was not created in AcquireResources"; runtime_dma.reset(); - CHECK(runtime_threads) << "runtime_threads was not created in AcquireResources"; + TVM_FFI_ICHECK(runtime_threads) << "runtime_threads was not created in AcquireResources"; runtime_threads.reset(); - CHECK(runtime_hexbuffs) << "runtime_hexbuffs was not created in AcquireResources"; + TVM_FFI_ICHECK(runtime_hexbuffs) << "runtime_hexbuffs was not created in AcquireResources"; runtime_hexbuffs.reset(); - CHECK(runtime_vtcm) << "runtime_vtcm was not created in AcquireResources"; + TVM_FFI_ICHECK(runtime_vtcm) << "runtime_vtcm was not created in AcquireResources"; runtime_vtcm.reset(); - CHECK(runtime_power_manager) << "runtime_power_manager was not created in AcquireResources"; + TVM_FFI_ICHECK(runtime_power_manager) + << "runtime_power_manager was not created in AcquireResources"; runtime_power_manager.reset(); } @@ -149,17 +150,17 @@ class HexagonDeviceAPI final : public DeviceAPI { void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) final; HexagonThreadManager* ThreadManager() { - CHECK(runtime_threads) << "runtime_threads has not been created"; + TVM_FFI_ICHECK(runtime_threads) << "runtime_threads has not been created"; return runtime_threads.get(); } HexagonUserDMA* UserDMA() { - CHECK(runtime_dma) << "runtime_dma has not been created"; + TVM_FFI_ICHECK(runtime_dma) << "runtime_dma has not been created"; return runtime_dma.get(); } HexagonVtcmPool* VtcmPool() { - CHECK(runtime_vtcm) << "runtime_vtcm has not been created"; + TVM_FFI_ICHECK(runtime_vtcm) << "runtime_vtcm has not been created"; return runtime_vtcm.get(); } diff --git a/src/runtime/hexagon/hexagon_htp.cc b/src/runtime/hexagon/hexagon_htp.cc index ac1b267902c7..c12a319252ba 100644 --- a/src/runtime/hexagon/hexagon_htp.cc +++ b/src/runtime/hexagon/hexagon_htp.cc @@ -52,15 +52,15 @@ void HexagonHtp::Acquire() { int nErr; if ((nErr = HAP_compute_res_attr_init(&compute_res_attr))) { - LOG(FATAL) << "InternalError: HAP_compute_res_attr_init failed\n"; + TVM_FFI_THROW(InternalError) << "HAP_compute_res_attr_init failed\n"; } if ((nErr = HAP_compute_res_attr_set_hmx_param(&compute_res_attr, 1))) { - LOG(FATAL) << "InternalError: HAP_compute_res_attr_set_hmx_param failed\n"; + TVM_FFI_THROW(InternalError) << "HAP_compute_res_attr_set_hmx_param failed\n"; } context_id_ = HAP_compute_res_acquire(&compute_res_attr, COMPUTE_RES_ACQ_TIMEOUT); if (!context_id_) { - LOG(FATAL) << "InternalError: HAP_compute_res_acquire failed\n"; + TVM_FFI_THROW(InternalError) << "HAP_compute_res_acquire failed\n"; } } @@ -70,7 +70,7 @@ void HexagonHtp::Lock() { int nErr; if ((nErr = HAP_compute_res_hmx_lock(context_id_))) { - LOG(FATAL) << "InternalError: Unable to lock HTP!"; + TVM_FFI_THROW(InternalError) << "Unable to lock HTP!"; } } diff --git a/src/runtime/hexagon/hexagon_hvx.cc b/src/runtime/hexagon/hexagon_hvx.cc index 4fc97bf95475..b71ef8b44f55 100644 --- a/src/runtime/hexagon/hexagon_hvx.cc +++ b/src/runtime/hexagon/hexagon_hvx.cc @@ -37,22 +37,23 @@ HexagonHvx::~HexagonHvx() { Release(); } void HexagonHvx::Acquire() { reserved_count_ = qurt_hvx_reserve(QURT_HVX_RESERVE_ALL); - CHECK(reserved_count_ == QURT_HVX_RESERVE_ALL) << "error reserving HVX: " << reserved_count_; + TVM_FFI_ICHECK(reserved_count_ == QURT_HVX_RESERVE_ALL) + << "error reserving HVX: " << reserved_count_; } void HexagonHvx::Release() { int rel = qurt_hvx_cancel_reserve(); - CHECK(rel == 0) << "error releasing HVX: " << rel; + TVM_FFI_ICHECK(rel == 0) << "error releasing HVX: " << rel; } void HexagonHvx::Lock() { int lck = qurt_hvx_lock(QURT_HVX_MODE_128B); - CHECK(lck == 0) << "error locking HVX: " << lck; + TVM_FFI_ICHECK(lck == 0) << "error locking HVX: " << lck; } void HexagonHvx::Unlock() { int unl = qurt_hvx_unlock(); - CHECK(unl == 0) << "error unlocking HVX: " << unl; + TVM_FFI_ICHECK(unl == 0) << "error unlocking HVX: " << unl; } } // namespace hexagon diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon_module.cc index 4b91fc9a8c23..dd9d74c202a4 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/runtime/hexagon/hexagon_module.cc @@ -43,7 +43,7 @@ HexagonModuleNode::HexagonModuleNode(std::string data, std::string fmt, : data_(data), fmt_(fmt), fmap_(fmap), asm_(asm_str), obj_(obj_str), ir_(ir_str), bc_(bc_str) {} ffi::Optional HexagonModuleNode::GetFunction(const ffi::String& name) { - LOG(FATAL) << "HexagonModuleNode::GetFunction is not implemented."; + TVM_FFI_THROW(InternalError) << "HexagonModuleNode::GetFunction is not implemented."; } ffi::String HexagonModuleNode::InspectSource(const ffi::String& format) const { @@ -63,19 +63,20 @@ void HexagonModuleNode::WriteToFile(const ffi::String& file_name, const ffi::Str SaveMetaDataToFile(meta_file, fmap_); CopyFile(data_, file_name); } else if (fmt == "s" || fmt == "asm") { - ICHECK(!asm_.empty()) << "Assembler source not available"; + TVM_FFI_ICHECK(!asm_.empty()) << "Assembler source not available"; SaveBinaryToFile(file_name, asm_); } else if (fmt == "o" || fmt == "obj") { - ICHECK(!obj_.empty()) << "Object data not available"; + TVM_FFI_ICHECK(!obj_.empty()) << "Object data not available"; SaveBinaryToFile(file_name, obj_); } else if (fmt == "ll") { - ICHECK(!ir_.empty()) << "LLVM IR source not available"; + TVM_FFI_ICHECK(!ir_.empty()) << "LLVM IR source not available"; SaveBinaryToFile(file_name, ir_); } else if (fmt == "bc") { - ICHECK(!bc_.empty()) << "LLVM IR bitcode not available"; + TVM_FFI_ICHECK(!bc_.empty()) << "LLVM IR bitcode not available"; SaveBinaryToFile(file_name, bc_); } else { - LOG(FATAL) << "HexagonModuleNode::SaveToFile: unhandled format `" << fmt << "'"; + TVM_FFI_THROW(InternalError) << "HexagonModuleNode::SaveToFile: unhandled format `" << fmt + << "'"; } } diff --git a/src/runtime/hexagon/hexagon_thread_manager.cc b/src/runtime/hexagon/hexagon_thread_manager.cc index a6ae62e39fa5..c1c3eadc3126 100644 --- a/src/runtime/hexagon/hexagon_thread_manager.cc +++ b/src/runtime/hexagon/hexagon_thread_manager.cc @@ -28,15 +28,15 @@ HexagonThreadManager::HexagonThreadManager(unsigned num_threads, unsigned thread const std::vector hw_resources) { // Note: could technically manage more software threads than allowable hardware threads, but // there is no system constant defined in the qurt libs for that maximum. - CHECK(num_threads); - CHECK_LE(num_threads, QURT_MAX_HTHREAD_LIMIT); + TVM_FFI_ICHECK(num_threads); + TVM_FFI_ICHECK_LE(num_threads, QURT_MAX_HTHREAD_LIMIT); nthreads_ = num_threads; - CHECK_GE(thread_stack_size_bytes, MIN_STACK_SIZE_BYTES); - CHECK_LE(thread_stack_size_bytes, MAX_STACK_SIZE_BYTES); + TVM_FFI_ICHECK_GE(thread_stack_size_bytes, MIN_STACK_SIZE_BYTES); + TVM_FFI_ICHECK_LE(thread_stack_size_bytes, MAX_STACK_SIZE_BYTES); - CHECK_GE(thread_pipe_size_words, MIN_PIPE_SIZE_WORDS); - CHECK_LE(thread_pipe_size_words, MAX_PIPE_SIZE_WORDS); + TVM_FFI_ICHECK_GE(thread_pipe_size_words, MIN_PIPE_SIZE_WORDS); + TVM_FFI_ICHECK_LE(thread_pipe_size_words, MAX_PIPE_SIZE_WORDS); hw_resources_ = hw_resources; CheckResources(); @@ -120,7 +120,7 @@ HexagonThreadManager::~HexagonThreadManager() { void HexagonThreadManager::CheckResources() { create_resource_managers_ = false; - CHECK(hw_resources_.empty() || hw_resources_.size() == nthreads_) + TVM_FFI_ICHECK(hw_resources_.empty() || hw_resources_.size() == nthreads_) << "Thread count must match resource count"; if (!hw_resources_.empty()) { // Ensure that no more than one of each hardware resource is specified @@ -128,7 +128,7 @@ void HexagonThreadManager::CheckResources() { if (hw_resources_[i] != NONE) { create_resource_managers_ = true; for (int j = i + 1; j < hw_resources_.size(); j++) { - CHECK(hw_resources_[i] != hw_resources_[j]) + TVM_FFI_ICHECK(hw_resources_[i] != hw_resources_[j]) << "No more than one of each resource type may be specified " << hw_resources_[i]; } } @@ -164,7 +164,7 @@ void HexagonThreadManager::SpawnThreads(unsigned thread_stack_size_bytes, // create the pipe int rc = qurt_pipe_init(&pipes_[i], &pipe_attr); - CHECK_EQ(rc, QURT_EOK); + TVM_FFI_ICHECK_EQ(rc, QURT_EOK); } DLOG(INFO) << "Pipes created"; @@ -186,7 +186,7 @@ void HexagonThreadManager::SpawnThreads(unsigned thread_stack_size_bytes, contexts_[i] = new ThreadContext(&pipes_[i], i, hw_resources_.empty() ? NONE : hw_resources_[i], hvx_.get(), htp_.get()); int rc = qurt_thread_create(&threads_[i], &thread_attr, thread_main, contexts_[i]); - CHECK_EQ(rc, QURT_EOK); + TVM_FFI_ICHECK_EQ(rc, QURT_EOK); } DLOG(INFO) << "Threads created"; @@ -207,11 +207,11 @@ TVMStreamHandle HexagonThreadManager::GetStreamHandleByResourceType(HardwareReso return reinterpret_cast(i); } } - CHECK(false) << "Thread for resource type " << type << " not found"; + TVM_FFI_ICHECK(false) << "Thread for resource type " << type << " not found"; } HardwareResourceType HexagonThreadManager::GetResourceTypeForStreamHandle(TVMStreamHandle thread) { - CHECK(hw_resources_.size() > reinterpret_cast(thread)) + TVM_FFI_ICHECK(hw_resources_.size() > reinterpret_cast(thread)) << "No thread for handle id exists " << thread; return hw_resources_[reinterpret_cast(thread)]; } diff --git a/src/runtime/hexagon/hexagon_thread_manager.h b/src/runtime/hexagon/hexagon_thread_manager.h index 7ec3ac61506d..83c5316a7259 100644 --- a/src/runtime/hexagon/hexagon_thread_manager.h +++ b/src/runtime/hexagon/hexagon_thread_manager.h @@ -156,7 +156,7 @@ class HexagonThreadManager { ThreadContext(qurt_pipe_t* pipe, unsigned index, HardwareResourceType resource_type, HexagonHvx* hvx, HexagonHtp* htp) : pipe(pipe), index(index), resource_type(resource_type), hvx(hvx), htp(htp), status(0) { - CHECK(resource_type == NONE || (hvx && htp)) + TVM_FFI_ICHECK(resource_type == NONE || (hvx && htp)) << "Missing resource manager pointer, type: " << resource_type << " hvx: " << hvx << " htp: " << htp; } diff --git a/src/runtime/hexagon/hexagon_user_dma.cc b/src/runtime/hexagon/hexagon_user_dma.cc index 11214a46e809..dddd85720a73 100644 --- a/src/runtime/hexagon/hexagon_user_dma.cc +++ b/src/runtime/hexagon/hexagon_user_dma.cc @@ -120,7 +120,7 @@ uint32_t HexagonUserDMA::DMAGroupsInFlight(uint32_t queue_id) { HexagonUserDMA::HexagonUserDMA() { // reset DMA engine unsigned int status = Init(); - CHECK_EQ(status, DM0_STATUS_IDLE); + TVM_FFI_ICHECK_EQ(status, DM0_STATUS_IDLE); auto desc_in_flight = [](dma_desc_2d_t* dma_desc) { unsigned int done = dma_desc_get_done(dma_desc); diff --git a/src/runtime/hexagon/hexagon_vtcm_pool.cc b/src/runtime/hexagon/hexagon_vtcm_pool.cc index 8373ef61c9a4..ef3dc592f003 100644 --- a/src/runtime/hexagon/hexagon_vtcm_pool.cc +++ b/src/runtime/hexagon/hexagon_vtcm_pool.cc @@ -37,7 +37,7 @@ HexagonVtcmPool::HexagonVtcmPool() { &total_block_layout, &avail_block_size, &avail_block_layout)); DLOG(INFO) << "HexagonVtcmPool total " << vtcm_device_size_ << " avail " << avail_block_size; - CHECK(avail_block_size >= (1024 * 1024)) << "Less than 1MB VTCM available"; + TVM_FFI_ICHECK(avail_block_size >= (1024 * 1024)) << "Less than 1MB VTCM available"; // allocate nbytes of vtcm on a single page HEXAGON_SAFE_CALL(HAP_compute_res_attr_set_vtcm_param_v2(&res_info, @@ -48,11 +48,13 @@ HexagonVtcmPool::HexagonVtcmPool() { // TODO(HWE): Investigate why a non-zero timeout results in // hanging, both in the simulator and on hardware. context_id_ = HAP_compute_res_acquire(&res_info, /*timeout = */ 0); - CHECK(context_id_) << "HAP_compute_res_acquire failed to acquire requested VTCM resource."; + TVM_FFI_ICHECK(context_id_) + << "HAP_compute_res_acquire failed to acquire requested VTCM resource."; HEXAGON_SAFE_CALL( HAP_compute_res_attr_get_vtcm_ptr_v2(&res_info, &vtcm_data_, &vtcm_allocated_size_)); - CHECK(vtcm_data_ != nullptr) << "HAP_compute_res_acquire returned nullptr when allocating VTCM."; - CHECK(vtcm_allocated_size_ >= avail_block_size) + TVM_FFI_ICHECK(vtcm_data_ != nullptr) + << "HAP_compute_res_acquire returned nullptr when allocating VTCM."; + TVM_FFI_ICHECK(vtcm_allocated_size_ >= avail_block_size) << "HAP_compute_res_acquire failed to allocate minimum amount of VTCM"; free_.emplace_back( std::pair(static_cast(vtcm_data_), vtcm_allocated_size_)); @@ -64,15 +66,15 @@ HexagonVtcmPool::~HexagonVtcmPool() { HEXAGON_SAFE_CALL(HAP_compute_res_release( void* HexagonVtcmPool::Allocate(size_t nbytes) { std::lock_guard lock(mutex_); - CHECK(!free_.empty()) << "No free VTCM"; - CHECK(nbytes >= 0x80) << "Minimum VTCM alloation must be 128 bytes - nbytes " << nbytes; + TVM_FFI_ICHECK(!free_.empty()) << "No free VTCM"; + TVM_FFI_ICHECK(nbytes >= 0x80) << "Minimum VTCM alloation must be 128 bytes - nbytes " << nbytes; // If this is not aligned on a 2k block, allocate from the end to avoid fragmentation if (nbytes & size_t(0x7FF)) { DLOG(INFO) << "VTCM nbytes requested: " << nbytes << " allocate from the end"; auto last_free_entry = free_.end(); last_free_entry--; - CHECK(last_free_entry->second >= nbytes) + TVM_FFI_ICHECK(last_free_entry->second >= nbytes) << "Not enough contiguous VTCM space at the end to allocate"; char* ptr = last_free_entry->first + (last_free_entry->second - nbytes); allocations_.emplace_back(std::pair(ptr, nbytes)); @@ -94,7 +96,8 @@ void* HexagonVtcmPool::Allocate(size_t nbytes) { } } } - CHECK(entry_to_allocate->second >= nbytes) << "Not enough contiguous VTCM space to allocate"; + TVM_FFI_ICHECK(entry_to_allocate->second >= nbytes) + << "Not enough contiguous VTCM space to allocate"; char* ptr = entry_to_allocate->first; allocations_.emplace(allocations_.end(), std::pair(ptr, nbytes)); @@ -114,8 +117,9 @@ void HexagonVtcmPool::Free(void* ptr, size_t nbytes) { auto it = std::find_if(allocations_.begin(), allocations_.end(), [&](auto entry) { return entry.first == ptr_to_free; }); - CHECK(it != allocations_.end()) << "Attempted to free a pointer that had not been allocated"; - CHECK(it->second == nbytes) << "Attempted to free a different size than was allocated"; + TVM_FFI_ICHECK(it != allocations_.end()) + << "Attempted to free a pointer that had not been allocated"; + TVM_FFI_ICHECK(it->second == nbytes) << "Attempted to free a different size than was allocated"; allocations_.erase(it); it = std::lower_bound(free_.begin(), free_.end(), std::pair(ptr_to_free, nbytes), @@ -124,8 +128,9 @@ void HexagonVtcmPool::Free(void* ptr, size_t nbytes) { // Insert an entry at the end it = free_.emplace(it, std::pair(ptr_to_free, nbytes)); } else { - CHECK(ptr_to_free != it->first) << "Attempting to free a pointer that was already free"; - CHECK(ptr_to_free + nbytes <= it->first) + TVM_FFI_ICHECK(ptr_to_free != it->first) + << "Attempting to free a pointer that was already free"; + TVM_FFI_ICHECK(ptr_to_free + nbytes <= it->first) << "free_ is in an inconsistent state, freed block overlaps with next"; if (ptr_to_free + nbytes == it->first) { // Make this entry bigger @@ -141,7 +146,7 @@ void HexagonVtcmPool::Free(void* ptr, size_t nbytes) { if (it != free_.begin()) { auto it_prev = it; it_prev--; - CHECK(it_prev->first + it_prev->second <= ptr_to_free) + TVM_FFI_ICHECK(it_prev->first + it_prev->second <= ptr_to_free) << "free_ is in an inconsistent state, freed block overlaps with previous"; if (it_prev->first + it_prev->second == ptr_to_free) { it_prev->second += it->second; diff --git a/src/runtime/hexagon/hexagon_vtcm_pool.h b/src/runtime/hexagon/hexagon_vtcm_pool.h index d9918a873aa9..0f7153eb54f6 100644 --- a/src/runtime/hexagon/hexagon_vtcm_pool.h +++ b/src/runtime/hexagon/hexagon_vtcm_pool.h @@ -75,9 +75,9 @@ class HexagonVtcmPool { bool IsVtcm(void* ptr, unsigned size) { auto char_ptr = static_cast(ptr); - CHECK(char_ptr != nullptr); + TVM_FFI_ICHECK(char_ptr != nullptr); auto char_vtcm = static_cast(vtcm_data_); - CHECK(vtcm_data_ != nullptr); + TVM_FFI_ICHECK(vtcm_data_ != nullptr); if (char_ptr >= char_vtcm && (char_ptr + size) <= (char_vtcm + vtcm_allocated_size_)) { return true; diff --git a/src/runtime/hexagon/ops/conv2d_fp16_hvx.cc b/src/runtime/hexagon/ops/conv2d_fp16_hvx.cc index 5f171894d9cd..1f60c3638543 100644 --- a/src/runtime/hexagon/ops/conv2d_fp16_hvx.cc +++ b/src/runtime/hexagon/ops/conv2d_fp16_hvx.cc @@ -192,8 +192,8 @@ void conv_layer_fp16_hvx(DLTensor& cr_out, const DLTensor& cr_act, // NOLINT(*) << ", filt_idepth=" << filt_idepth << ", pad_top=" << pad_top << ", pad_left=" << pad_left << "\n"; - ICHECK_LT(pad_top, 8) << "pad_top offset cannot be >= 8"; - ICHECK_LT(pad_left, 4) << "pad_left offset cannot be >= 4"; + TVM_FFI_ICHECK_LT(pad_top, 8) << "pad_top offset cannot be >= 8"; + TVM_FFI_ICHECK_LT(pad_left, 4) << "pad_left offset cannot be >= 4"; int a_height = cr_act.shape[1]; int a_width = cr_act.shape[2]; @@ -217,8 +217,9 @@ void conv_layer_fp16_hvx(DLTensor& cr_out, const DLTensor& cr_act, // NOLINT(*) << o_depth << ", b: " << b_depth << ", out_shape: " << out_height << "x" << out_width << "\n"; - ICHECK_EQ(a_depth, cr_filt.shape[2]) << "input depth should match weights input channels"; - ICHECK_EQ(o_depth, cr_filt.shape[3]) << "output depth should match the weights output channel"; + TVM_FFI_ICHECK_EQ(a_depth, cr_filt.shape[2]) << "input depth should match weights input channels"; + TVM_FFI_ICHECK_EQ(o_depth, cr_filt.shape[3]) + << "output depth should match the weights output channel"; int rd = round_down(filt_width, 4); int wgt_chunk_thin_width = filt_width - rd; @@ -404,20 +405,20 @@ void conv_layer_fp16_hvx(DLTensor& cr_out, const DLTensor& cr_act, // NOLINT(*) int conv2d_packed_fp16(void*, TVMFFIAny* args, int num_args, TVMFFIAny* out_val) { namespace conv_utils = tvm::runtime::hexagon::conv_utils; - ICHECK_EQ(num_args, 7) << "Unexpected number of arguments"; - ICHECK_EQ(args[0].type_index, kTVMFFIDLTensorPtr) + TVM_FFI_ICHECK_EQ(num_args, 7) << "Unexpected number of arguments"; + TVM_FFI_ICHECK_EQ(args[0].type_index, kTVMFFIDLTensorPtr) << "First argument is expected to be the input tensor"; // Input activations - ICHECK_EQ(args[1].type_index, kTVMFFIDLTensorPtr) + TVM_FFI_ICHECK_EQ(args[1].type_index, kTVMFFIDLTensorPtr) << "Second argument is expected to be the weights tensor"; // Weights - ICHECK_EQ(args[2].type_index, kTVMFFIInt) + TVM_FFI_ICHECK_EQ(args[2].type_index, kTVMFFIInt) << "Third argument is expected to be the pad_top offset"; // pad_top offset - ICHECK_EQ(args[3].type_index, kTVMFFIInt) + TVM_FFI_ICHECK_EQ(args[3].type_index, kTVMFFIInt) << "Fourth argument is expected to be the pad_left offset"; // pad_left offset - ICHECK_EQ(args[4].type_index, kTVMFFIInt) + TVM_FFI_ICHECK_EQ(args[4].type_index, kTVMFFIInt) << "Fifth argument is expected to be the stride_h"; // stride_h - ICHECK_EQ(args[5].type_index, kTVMFFIInt) + TVM_FFI_ICHECK_EQ(args[5].type_index, kTVMFFIInt) << "Sixth argument is expected to be the stride_w"; // stride_w - ICHECK_EQ(args[6].type_index, kTVMFFIDLTensorPtr) + TVM_FFI_ICHECK_EQ(args[6].type_index, kTVMFFIDLTensorPtr) << "Seventh argument is expected to be the output tensor"; // output auto* act_flat = static_cast(args[0].v_ptr); @@ -425,10 +426,10 @@ int conv2d_packed_fp16(void*, TVMFFIAny* args, int num_args, TVMFFIAny* out_val) auto* out_flat = static_cast(args[6].v_ptr); // Temporary assertion until multiple batches are supported - ICHECK_EQ(act_flat->shape[0], 1) << "Input batch size more than 1 is not supported yet"; + TVM_FFI_ICHECK_EQ(act_flat->shape[0], 1) << "Input batch size more than 1 is not supported yet"; // Temporary assertion until multiple batches are supported - ICHECK_EQ(out_flat->shape[0], 1) << "Output batch size more than 1 is not supported yet"; + TVM_FFI_ICHECK_EQ(out_flat->shape[0], 1) << "Output batch size more than 1 is not supported yet"; int pad_top = args[2].v_int64; int pad_left = args[3].v_int64; @@ -442,16 +443,16 @@ int conv2d_packed_fp16(void*, TVMFFIAny* args, int num_args, TVMFFIAny* out_val) << ", pad_left=" << pad_left; auto* device_api = tvm::runtime::DeviceAPI::Get(conv_utils::hexagon_device, false); - ICHECK(device_api != nullptr); + TVM_FFI_ICHECK(device_api != nullptr); tvm::ffi::String vtcm_scope = "global.vtcm"; auto act_vtcm = conv_utils::prepare_nhwc(device_api, act_flat, /*copy_data=*/true); - ICHECK_NE(wgt_flat->shape[0], 0) << "Weights height should not be zero"; - ICHECK_NE(wgt_flat->shape[1], 0) << "Weights width should not be zero"; - ICHECK_NE(wgt_flat->shape[2], 0) << "Weights input channels should not be zero"; - ICHECK_NE(wgt_flat->shape[3], 0) << "Weights output channels should not be zero"; + TVM_FFI_ICHECK_NE(wgt_flat->shape[0], 0) << "Weights height should not be zero"; + TVM_FFI_ICHECK_NE(wgt_flat->shape[1], 0) << "Weights width should not be zero"; + TVM_FFI_ICHECK_NE(wgt_flat->shape[2], 0) << "Weights input channels should not be zero"; + TVM_FFI_ICHECK_NE(wgt_flat->shape[3], 0) << "Weights output channels should not be zero"; int num_wgt_chunks = conv_utils::calculate_num_weight_chunks( wgt_flat->shape, /* chunk_height */ 8, /* chunk_width */ 4, /* chunk_in_channel */ 32, /* chunk_out_channel */ 32); diff --git a/src/runtime/hexagon/ops/conv2d_quant_hvx.cc b/src/runtime/hexagon/ops/conv2d_quant_hvx.cc index 30cba60cf1a8..bc6f2f928554 100644 --- a/src/runtime/hexagon/ops/conv2d_quant_hvx.cc +++ b/src/runtime/hexagon/ops/conv2d_quant_hvx.cc @@ -111,8 +111,9 @@ void conv_layer_int8_hvx_whole(DLTensor& cr_out, const DLTensor& cr_act, // NOL HVX_Vector wgt_zp_vec = Q6_Vb_vsplat_R(wgt_zp_i8); HVX_VectorPair wgt_zp_vec_pair = Q6_Wh_vsxt_Vb(wgt_zp_vec); - ICHECK_EQ(a_depth, cr_filt.shape[2]) << "input depth should match weights input channels"; - ICHECK_EQ(o_depth, cr_filt.shape[3]) << "output depth should match the weights output channel"; + TVM_FFI_ICHECK_EQ(a_depth, cr_filt.shape[2]) << "input depth should match weights input channels"; + TVM_FFI_ICHECK_EQ(o_depth, cr_filt.shape[3]) + << "output depth should match the weights output channel"; uint32_t scale_u = static_cast(fixed_final_scale); HVX_Vector scale_vec = Q6_V_vsplat_R(scale_u); @@ -231,31 +232,32 @@ void conv_layer_int8_hvx_whole(DLTensor& cr_out, const DLTensor& cr_act, // NOL int conv2d_packed_quant(void*, TVMFFIAny* args, int num_args, TVMFFIAny* out_val) { namespace conv_utils = tvm::runtime::hexagon::conv_utils; - ICHECK_EQ(num_args, 13) << "Unexpected number of arguments"; - ICHECK_EQ(args[0].type_index, kTVMFFIDLTensorPtr) + TVM_FFI_ICHECK_EQ(num_args, 13) << "Unexpected number of arguments"; + TVM_FFI_ICHECK_EQ(args[0].type_index, kTVMFFIDLTensorPtr) << "First argument is expected to be the input tensor"; // Input activations - ICHECK_EQ(args[1].type_index, kTVMFFIDLTensorPtr) + TVM_FFI_ICHECK_EQ(args[1].type_index, kTVMFFIDLTensorPtr) << "Second argument is expected to be the weights tensor"; // Weights - ICHECK_EQ(args[2].type_index, kTVMFFIFloat) + TVM_FFI_ICHECK_EQ(args[2].type_index, kTVMFFIFloat) << "Third argument is expected to be the activation scale"; - ICHECK_EQ(args[3].type_index, kTVMFFIInt) + TVM_FFI_ICHECK_EQ(args[3].type_index, kTVMFFIInt) << "Fourth argument is expected to be the activation zero point"; - ICHECK_EQ(args[4].type_index, kTVMFFIFloat) + TVM_FFI_ICHECK_EQ(args[4].type_index, kTVMFFIFloat) << "Fifth argument is expected to be the weight scale"; - ICHECK_EQ(args[5].type_index, kTVMFFIInt) + TVM_FFI_ICHECK_EQ(args[5].type_index, kTVMFFIInt) << "Sixth argument is expected to be the weight zero point"; - ICHECK_EQ(args[6].type_index, kTVMFFIFloat) + TVM_FFI_ICHECK_EQ(args[6].type_index, kTVMFFIFloat) << "Seventh argument is expected to be the output scale"; - ICHECK_EQ(args[7].type_index, kTVMFFIInt) + TVM_FFI_ICHECK_EQ(args[7].type_index, kTVMFFIInt) << "Eigth argument is expected to be the output zero point"; - ICHECK_EQ(args[8].type_index, kTVMFFIInt) + TVM_FFI_ICHECK_EQ(args[8].type_index, kTVMFFIInt) << "Nineth argument is expected to be the stride_h"; // stride_h - ICHECK_EQ(args[9].type_index, kTVMFFIInt) + TVM_FFI_ICHECK_EQ(args[9].type_index, kTVMFFIInt) << "Tenth argument is expected to be the stride_w"; // stride_w - ICHECK_EQ(args[10].type_index, kTVMFFIInt) + TVM_FFI_ICHECK_EQ(args[10].type_index, kTVMFFIInt) << "Eleventh argument is expected to be fixed final scale"; - ICHECK_EQ(args[11].type_index, kTVMFFIInt) << "Twelfth argument is expected to be scale factor"; - ICHECK_EQ(args[12].type_index, kTVMFFIDLTensorPtr) + TVM_FFI_ICHECK_EQ(args[11].type_index, kTVMFFIInt) + << "Twelfth argument is expected to be scale factor"; + TVM_FFI_ICHECK_EQ(args[12].type_index, kTVMFFIDLTensorPtr) << "Thirteenth argument is expected to be the output tensor"; // output auto* act_flat = static_cast(args[0].v_ptr); @@ -263,10 +265,10 @@ int conv2d_packed_quant(void*, TVMFFIAny* args, int num_args, TVMFFIAny* out_val auto* out_flat = static_cast(args[12].v_ptr); // Temporary assertion until multiple batches are supported - ICHECK_EQ(act_flat->shape[0], 1) << "Input batch size more than 1 is not supported yet"; + TVM_FFI_ICHECK_EQ(act_flat->shape[0], 1) << "Input batch size more than 1 is not supported yet"; // Temporary assertion until multiple batches are supported - ICHECK_EQ(out_flat->shape[0], 1) << "Output batch size more than 1 is not supported yet"; + TVM_FFI_ICHECK_EQ(out_flat->shape[0], 1) << "Output batch size more than 1 is not supported yet"; float act_scale = args[2].v_float64; int act_zp = args[3].v_int64; @@ -289,7 +291,7 @@ int conv2d_packed_quant(void*, TVMFFIAny* args, int num_args, TVMFFIAny* out_val LOG_INFO << "fixed_final_scale: " << fixed_final_scale << ", scale_factor: " << scale_factor; auto* device_api = tvm::runtime::DeviceAPI::Get(conv_utils::hexagon_device, false); - ICHECK(device_api != nullptr); + TVM_FFI_ICHECK(device_api != nullptr); tvm::ffi::String vtcm_scope = "global.vtcm"; auto act_vtcm = diff --git a/src/runtime/hexagon/qhl/qhl_wrapper.cc b/src/runtime/hexagon/qhl/qhl_wrapper.cc index df188c8907e5..e1515ecc7e08 100644 --- a/src/runtime/hexagon/qhl/qhl_wrapper.cc +++ b/src/runtime/hexagon/qhl/qhl_wrapper.cc @@ -61,7 +61,9 @@ template HVX_Vector wrapper_api(HVX_Vector input, qhlFptr qhl_api, const char* qhl_api_name) { HVX_Vector output; int32_t res = qhl_api(reinterpret_cast(&input), reinterpret_cast(&output), 64); - if (res != 0) LOG(FATAL) << "Error. Failed execution of " << qhl_api_name << " Error=" << res; + if (res != 0) + TVM_FFI_THROW(InternalError) << "Error. Failed execution of " << qhl_api_name + << " Error=" << res; return output; } @@ -70,7 +72,9 @@ HVX_Vector wrapper_api(HVX_Vector ip1, HVX_Vector ip2, qhlFptr2 qhl_api, const c HVX_Vector output; int32_t res = qhl_api(reinterpret_cast(&ip1), reinterpret_cast(&ip2), reinterpret_cast(&output), 64); - if (res != 0) LOG(FATAL) << "Error. Failed execution of " << qhl_api_name << "Error=" << res; + if (res != 0) + TVM_FFI_THROW(InternalError) << "Error. Failed execution of " << qhl_api_name + << "Error=" << res; return output; } diff --git a/src/runtime/hexagon/ring_buffer.h b/src/runtime/hexagon/ring_buffer.h index 91adad6a65e5..4d5df5a9ca5d 100644 --- a/src/runtime/hexagon/ring_buffer.h +++ b/src/runtime/hexagon/ring_buffer.h @@ -58,11 +58,11 @@ class RingBuffer { */ RingBuffer(uint32_t ring_buff_size, std::function in_flight) : ring_buff_size_(ring_buff_size), in_flight_(in_flight) { - CHECK_NE(ring_buff_size, 0); + TVM_FFI_ICHECK_NE(ring_buff_size, 0); int ret = posix_memalign(reinterpret_cast(&ring_buff_ptr_), sizeof(T), sizeof(T) * ring_buff_size_); - CHECK_EQ(ret, 0); - CHECK_NE(ring_buff_ptr_, nullptr); + TVM_FFI_ICHECK_EQ(ret, 0); + TVM_FFI_ICHECK_NE(ring_buff_ptr_, nullptr); } ~RingBuffer() { free(ring_buff_ptr_); } @@ -103,7 +103,7 @@ class QueuedRingBuffer : RingBuffer { //! \brief Returns pointer to next T; add the queue ID for tracking T* Next(uint32_t queue_id) { - CHECK_LT(queue_id, max_queues_); + TVM_FFI_ICHECK_LT(queue_id, max_queues_); queue_ids_.push_back(queue_id); queue_descriptor* d = &queue_descriptors_[queue_id]; if (d->group_started) { @@ -119,9 +119,9 @@ class QueuedRingBuffer : RingBuffer { //! \brief Returns the number of groups of Ts in flight for a given queue ID uint32_t InFlight(uint32_t queue_id) { - CHECK_LT(queue_id, max_queues_); + TVM_FFI_ICHECK_LT(queue_id, max_queues_); queue_descriptor* d = &queue_descriptors_[queue_id]; - CHECK(!d->group_started); + TVM_FFI_ICHECK(!d->group_started); uint32_t in_flight = 0; // look at the queue IDs for the RingBuffer entries in flight @@ -144,9 +144,9 @@ class QueuedRingBuffer : RingBuffer { //! \brief Start a group of Ts, if not called the deafault group size is one void StartGroup(uint32_t queue_id) { - CHECK_LT(queue_id, max_queues_); + TVM_FFI_ICHECK_LT(queue_id, max_queues_); queue_descriptor* d = &queue_descriptors_[queue_id]; - CHECK(!d->group_started); + TVM_FFI_ICHECK(!d->group_started); // start group d->group_started = true; @@ -155,10 +155,10 @@ class QueuedRingBuffer : RingBuffer { //! \brief End a group of Ts void EndGroup(uint32_t queue_id) { - CHECK_LT(queue_id, max_queues_); + TVM_FFI_ICHECK_LT(queue_id, max_queues_); queue_descriptor* d = &queue_descriptors_[queue_id]; - CHECK(d->group_started); - CHECK(d->pending_in_group); + TVM_FFI_ICHECK(d->group_started); + TVM_FFI_ICHECK(d->pending_in_group); // create group if (d->pending_in_group) { diff --git a/src/runtime/hexagon/rpc/android/session.cc b/src/runtime/hexagon/rpc/android/session.cc index 55eee5df27f0..6052225e68f4 100644 --- a/src/runtime/hexagon/rpc/android/session.cc +++ b/src/runtime/hexagon/rpc/android/session.cc @@ -54,27 +54,28 @@ class HexagonTransportChannel : public RPCChannel { set_remote_stack_size(remote_stack_size_bytes); AEEResult rc = hexagon_rpc_open(uri.c_str(), &_handle); - ICHECK(rc == AEE_SUCCESS) << "hexagon_rpc_open failed. URI: " << uri.c_str(); + TVM_FFI_ICHECK(rc == AEE_SUCCESS) << "hexagon_rpc_open failed. URI: " << uri.c_str(); rc = hexagon_rpc_init(_handle, receive_buf_size_bytes); - ICHECK(rc == AEE_SUCCESS) << "hexagon_rpc_set_receive_buf_size failed. receive_buf_size_bytes: " - << receive_buf_size_bytes; + TVM_FFI_ICHECK(rc == AEE_SUCCESS) + << "hexagon_rpc_set_receive_buf_size failed. receive_buf_size_bytes: " + << receive_buf_size_bytes; } size_t Send(const void* data, size_t size) override { - ICHECK(_handle != AEE_EUNKNOWN) << "RPC handle is not initialized."; + TVM_FFI_ICHECK(_handle != AEE_EUNKNOWN) << "RPC handle is not initialized."; AEEResult rc = hexagon_rpc_send(_handle, static_cast(data), static_cast(size)); - ICHECK(rc == AEE_SUCCESS) << "hexagon_rpc_send failed: " << rc; + TVM_FFI_ICHECK(rc == AEE_SUCCESS) << "hexagon_rpc_send failed: " << rc; return size; } size_t Recv(void* data, size_t size) override { - ICHECK(_handle != AEE_EUNKNOWN) << "RPC handle is not initialized."; + TVM_FFI_ICHECK(_handle != AEE_EUNKNOWN) << "RPC handle is not initialized."; int64_t written_size = 0; AEEResult rc = hexagon_rpc_receive(_handle, static_cast(data), static_cast(size), &written_size); - ICHECK(rc == AEE_SUCCESS) << "hexagon_rpc_receive failed: " << rc; + TVM_FFI_ICHECK(rc == AEE_SUCCESS) << "hexagon_rpc_receive failed: " << rc; return static_cast(written_size); } @@ -114,7 +115,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "tvm.contrib.hexagon.create_hexagon_session", [](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK(args.size() >= 4) << args.size() << " is less than 4"; + TVM_FFI_ICHECK(args.size() >= 4) << args.size() << " is less than 4"; auto session_name = args[0].cast(); int remote_stack_size_bytes = args[1].cast(); diff --git a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc index d9c2e647aea2..c3bcf27a40fd 100644 --- a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc +++ b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc @@ -234,7 +234,7 @@ tvm::runtime::hexagon::HexagonRPCServer* get_hexagon_rpc_server( if (g_hexagon_rpc_server) { return g_hexagon_rpc_server; } - CHECK_GT(rpc_receive_buff_size_bytes, 0) << "RPC receive buffer size is not valid."; + TVM_FFI_ICHECK_GT(rpc_receive_buff_size_bytes, 0) << "RPC receive buffer size is not valid."; static tvm::runtime::hexagon::HexagonRPCServer hexagon_rpc_server( new uint8_t[rpc_receive_buff_size_bytes], rpc_receive_buff_size_bytes); g_hexagon_rpc_server = &hexagon_rpc_server; @@ -353,7 +353,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { void SaveBinaryToFile(const std::string& file_name, const std::string& data) { std::ofstream fs(file_name, std::ios::out | std::ios::binary); - ICHECK(!fs.fail()) << "Cannot open " << file_name; + TVM_FFI_ICHECK(!fs.fail()) << "Cannot open " << file_name; fs.write(&data[0], data.length()); } diff --git a/src/runtime/hexagon/rpc/simulator/rpc_server.cc b/src/runtime/hexagon/rpc/simulator/rpc_server.cc index c3cec3039221..fba9253e53f2 100644 --- a/src/runtime/hexagon/rpc/simulator/rpc_server.cc +++ b/src/runtime/hexagon/rpc/simulator/rpc_server.cc @@ -295,14 +295,14 @@ int main(int argc, char* argv[]) { // Load C++RT and ourselves as "global" to make all the symbols defined // there be visible to any subsequent libraries loaded via dlopen. void* cxx_abi = dlopen("libc++abi.so", RTLD_GLOBAL); - ICHECK(cxx_abi != nullptr); + TVM_FFI_ICHECK(cxx_abi != nullptr); void* cxx = dlopen("libc++.so", RTLD_GLOBAL); - ICHECK(cxx != nullptr); + TVM_FFI_ICHECK(cxx != nullptr); void* self = dlopen(argv[0], RTLD_GLOBAL); - ICHECK(self != nullptr); + TVM_FFI_ICHECK(self != nullptr); const auto api = tvm::ffi::Function::GetGlobal("device_api.hexagon"); - ICHECK(api.has_value()); + TVM_FFI_ICHECK(api.has_value()); tvm::ffi::Function::SetGlobal("device_api.cpu", *api, true); tvm::runtime::hexagon::SimulatorRPCServer server; @@ -357,7 +357,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { void SaveBinaryToFile(const std::string& file_name, const std::string& data) { std::ofstream fs(file_name, std::ios::out | std::ios::binary); - ICHECK(!fs.fail()) << "Cannot open " << file_name; + TVM_FFI_ICHECK(!fs.fail()) << "Cannot open " << file_name; fs.write(&data[0], data.length()); } diff --git a/src/runtime/hexagon/rpc/simulator/session.cc b/src/runtime/hexagon/rpc/simulator/session.cc index d7a9ade7234a..0864796a9ad9 100644 --- a/src/runtime/hexagon/rpc/simulator/session.cc +++ b/src/runtime/hexagon/rpc/simulator/session.cc @@ -43,7 +43,7 @@ #define CHECKED_CALL(func, ...) \ do { \ HEXAPI_Status s = sim_->func(__VA_ARGS__); \ - ICHECK_EQ(s, HEX_STAT_SUCCESS) \ + TVM_FFI_ICHECK_EQ(s, HEX_STAT_SUCCESS) \ << self_name_ << ": " #func " failed with code " << Status_{s}.str(); \ } while (false) @@ -65,7 +65,7 @@ class StringSwitch { if (f != map.end()) { return f->second; } - ICHECK(static_cast(def_val)) << "default value not set"; + TVM_FFI_ICHECK(static_cast(def_val)) << "default value not set"; return *def_val; } StringSwitch& Case(const std::string& key, T val) { @@ -73,7 +73,7 @@ class StringSwitch { return *this; } StringSwitch& Default(T val) { - ICHECK(!static_cast(def_val)) << "default value already set"; + TVM_FFI_ICHECK(!static_cast(def_val)) << "default value already set"; def_val = val; return *this; } @@ -468,7 +468,7 @@ std::string SimulatorRPCChannel::Cpu_::str() const { return default_cpu_; } -// LOG(FATAL) always throws an exception or terminates the +// TVM_FFI_THROW(InternalError) always throws an exception or terminates the // process, but the compiler doesn't know that. #if (__GNUC__) #pragma GCC diagnostic push @@ -497,7 +497,7 @@ std::string SimulatorRPCChannel::Message_::str() const { case Message::kSendEnd: return "kSendEnd"; default: - LOG(FATAL) << "Internal error: Unrecognized code value: " << msg.code; + TVM_FFI_THROW(InternalError) << "Internal error: Unrecognized code value: " << msg.code; break; } } @@ -525,7 +525,8 @@ SimulatorRPCChannel::SDKInfo_::SDKInfo_(const std::string& sdk_root, const std:: std::vector dir_names; DIR* dir = opendir((root + "/libs/run_main_on_hexagon/ship").c_str()); - ICHECK(dir != nullptr) << "Cannot read directory " << root + "/libs/run_main_on_hexagon/ship"; + TVM_FFI_ICHECK(dir != nullptr) << "Cannot read directory " + << root + "/libs/run_main_on_hexagon/ship"; while (dirent* d = readdir(dir)) { if (d->d_type != DT_DIR) continue; @@ -537,7 +538,7 @@ SimulatorRPCChannel::SDKInfo_::SDKInfo_(const std::string& sdk_root, const std:: } } closedir(dir); - ICHECK(!dir_names.empty()); + TVM_FFI_ICHECK(!dir_names.empty()); auto max_it = std::max_element(dir_names.begin(), dir_names.end()); runmain = root + "/libs/run_main_on_hexagon/ship/" + *max_it + "/run_main_on_hexagon_sim"; @@ -553,8 +554,8 @@ HEX_8u_t SimulatorRPCChannel::PassVirtAddrCallback(void* handle, int threadno, H LOG(INFO) << "dispatch:" << reinterpret_cast(rpc->dispatch_v_) << ", message buffer:" << reinterpret_cast(rpc->message_buffer_v_); HEXAPI_Status s = rpc->sim_->SetBreakpoint(rpc->dispatch_v_); - ICHECK_EQ(s, HEX_STAT_SUCCESS) << self_name_ << ": SetBreakpoint failed with code " - << Status_{s}.str(); + TVM_FFI_ICHECK_EQ(s, HEX_STAT_SUCCESS) + << self_name_ << ": SetBreakpoint failed with code " << Status_{s}.str(); return RssV; } @@ -587,9 +588,9 @@ std::optional SimulatorRPCChannel::GetCPU(const detail::MaybeString& SimulatorRPCChannel::SimulatorRPCChannel(int stack_size, std::string args) { const char* sdk_root_env = std::getenv("HEXAGON_SDK_ROOT"); - ICHECK(sdk_root_env != nullptr) << "Please set HEXAGON_SDK_ROOT"; + TVM_FFI_ICHECK(sdk_root_env != nullptr) << "Please set HEXAGON_SDK_ROOT"; const char* toolchain_env = std::getenv("HEXAGON_TOOLCHAIN"); - ICHECK(toolchain_env != nullptr) << "Please set HEXAGON_TOOLCHAIN"; + TVM_FFI_ICHECK(toolchain_env != nullptr) << "Please set HEXAGON_TOOLCHAIN"; std::string sdk_root(sdk_root_env); std::string toolchain(toolchain_env); @@ -605,7 +606,7 @@ SimulatorRPCChannel::SimulatorRPCChannel(int stack_size, std::string args) { LOG(INFO) << "CPU not given, defaulting to " << default_cpu_; maybe_cpu = GetCPU(std::string(default_cpu_)); } else { - LOG(FATAL) << "Invalid CPU name " << *target_str; + TVM_FFI_THROW(InternalError) << "Invalid CPU name " << *target_str; } } cpu_ = Cpu_{*maybe_cpu}.str(); @@ -614,18 +615,18 @@ SimulatorRPCChannel::SimulatorRPCChannel(int stack_size, std::string args) { // Prepare the osam.cfg file. int fd_osam = mkstemps(osam_file_, suffix_len_); - ICHECK_GE(fd_osam, 0); + TVM_FFI_ICHECK_GE(fd_osam, 0); std::string osam_str = sdk.qurt_root + "/debugger/lnx64/qurt_model.so"; - ICHECK_EQ(write(fd_osam, osam_str.c_str(), osam_str.size()), osam_str.size()); + TVM_FFI_ICHECK_EQ(write(fd_osam, osam_str.c_str(), osam_str.size()), osam_str.size()); close(fd_osam); // Prepare the q6ss.cfg file. int fd_cosim = mkstemps(cosim_file_, suffix_len_); - ICHECK_GE(fd_cosim, 0); + TVM_FFI_ICHECK_GE(fd_cosim, 0); std::string cosim_str = toolchain + "/lib/iss/qtimer.so --csr_base=0xFC900000 --irq_p=1 --freq=19200000 --cnttid=1\n" + toolchain + "/lib/iss/l2vic.so 32 0xFC910000"; - ICHECK_EQ(write(fd_cosim, cosim_str.c_str(), cosim_str.size()), cosim_str.size()); + TVM_FFI_ICHECK_EQ(write(fd_cosim, cosim_str.c_str(), cosim_str.size()), cosim_str.size()); close(fd_cosim); CHECKED_CALL(ConfigureL2tcmBase, 0xD800); @@ -650,7 +651,8 @@ SimulatorRPCChannel::SimulatorRPCChannel(int stack_size, std::string args) { HEX_4u_t result; HEXAPI_CoreState core = sim_->Run(&result); if (core != HEX_CORE_BREAKPOINT) { - LOG(FATAL) << self_name_ << ": Run not stopped on breakpoint, code=" << Core_{core}.str(); + TVM_FFI_THROW(InternalError) << self_name_ + << ": Run not stopped on breakpoint, code=" << Core_{core}.str(); } // At this point the simulator has executed the executable's initialization @@ -669,40 +671,40 @@ SimulatorRPCChannel::~SimulatorRPCChannel() { HEX_4u_t result; HEXAPI_CoreState core = sim_->Run(&result); - ICHECK_EQ(core, HEX_CORE_FINISHED); + TVM_FFI_ICHECK_EQ(core, HEX_CORE_FINISHED); unlink(osam_file_); unlink(cosim_file_); } size_t SimulatorRPCChannel::Send(const void* data, size_t size) { - ICHECK(size <= std::numeric_limits::max()); + TVM_FFI_ICHECK(size <= std::numeric_limits::max()); Message reply_start = SendMsg(Message::kReceiveStart, static_cast(size), Message::null_va); - ICHECK_EQ(reply_start.code, Message::kAck); - ICHECK_GE(reply_start.len, size); - ICHECK_NE(reply_start.va, Message::null_va); + TVM_FFI_ICHECK_EQ(reply_start.code, Message::kAck); + TVM_FFI_ICHECK_GE(reply_start.len, size); + TVM_FFI_ICHECK_NE(reply_start.va, Message::null_va); WriteToProcess(reply_start.va, data, size); Message reply_end = SendMsg(Message::kReceiveEnd, static_cast(size), reply_start.va); - ICHECK_EQ(reply_end.code, Message::kAck); + TVM_FFI_ICHECK_EQ(reply_end.code, Message::kAck); return size; } size_t SimulatorRPCChannel::Recv(void* data, size_t size) { - ICHECK(size <= std::numeric_limits::max()); + TVM_FFI_ICHECK(size <= std::numeric_limits::max()); Message reply_start = SendMsg(Message::kSendStart, static_cast(size), Message::null_va); - ICHECK_EQ(reply_start.code, Message::kAck); - ICHECK_GE(reply_start.len, size); - ICHECK_NE(reply_start.va, Message::null_va); + TVM_FFI_ICHECK_EQ(reply_start.code, Message::kAck); + TVM_FFI_ICHECK_GE(reply_start.len, size); + TVM_FFI_ICHECK_NE(reply_start.va, Message::null_va); ReadFromProcess(data, reply_start.va, size); Message reply_end = SendMsg(Message::kSendEnd, static_cast(size), reply_start.va); - ICHECK_EQ(reply_end.code, Message::kAck); + TVM_FFI_ICHECK_EQ(reply_end.code, Message::kAck); return size; } @@ -713,7 +715,7 @@ Message SimulatorRPCChannel::SendMsg(Message msg) { core = sim_->Run(&result); Core_ core_ = {core}; - ICHECK_EQ(core, HEX_CORE_BREAKPOINT) + TVM_FFI_ICHECK_EQ(core, HEX_CORE_BREAKPOINT) << "Expecting HEX_CORE_BREAKPOINT, received: " << core_.str(); }; @@ -782,17 +784,17 @@ bool SimulatorRPCChannel::Configure(string_list& opts) { std::string key = *detail::pop_front(opts); auto f = opt_map_.find(key); if (f == opt_map_.end()) { - LOG(FATAL) << "Unrecognized simulator option: " << key; + TVM_FFI_THROW(InternalError) << "Unrecognized simulator option: " << key; // unreachable } - ICHECK((this->*f->second)(opts)) << "error handling option: " << key; + TVM_FFI_ICHECK((this->*f->second)(opts)) << "error handling option: " << key; } // Check AHB. if (ahb_.first.has_value() && ahb_.second.has_value()) { CHECKED_CALL(ConfigureAHB, *ahb_.first, *ahb_.second); } else { - ICHECK(!ahb_.first.has_value() && !ahb_.second.has_value()) + TVM_FFI_ICHECK(!ahb_.first.has_value() && !ahb_.second.has_value()) << self_name_ << ": please specify both low and high addresses for AHB"; } @@ -800,7 +802,7 @@ bool SimulatorRPCChannel::Configure(string_list& opts) { if (axi2_.first.has_value() && axi2_.second.has_value()) { CHECKED_CALL(ConfigureAXI2, *axi2_.first, *axi2_.second); } else { - ICHECK(!axi2_.first.has_value() && !axi2_.second.has_value()) + TVM_FFI_ICHECK(!axi2_.first.has_value() && !axi2_.second.has_value()) << self_name_ << ": please specify both low and high addresses for AXI2"; } @@ -826,7 +828,7 @@ bool SimulatorRPCChannel::HandleAHBBusRatio(string_list& rest) { bool SimulatorRPCChannel::HandleAHBHighAddr(string_list& rest) { auto addr = detail::to_uint(detail::pop_front(rest)); - ICHECK(addr) << self_name_ << ": invalid value for AHB high adddress"; + TVM_FFI_ICHECK(addr) << self_name_ << ": invalid value for AHB high adddress"; if (addr) { ahb_.second = *addr; } @@ -835,7 +837,7 @@ bool SimulatorRPCChannel::HandleAHBHighAddr(string_list& rest) { bool SimulatorRPCChannel::HandleAHBLowAddr(string_list& rest) { auto addr = detail::to_uint(detail::pop_front(rest)); - ICHECK(addr) << self_name_ << ": invalid value for AHB low adddress"; + TVM_FFI_ICHECK(addr) << self_name_ << ": invalid value for AHB low adddress"; if (addr) { ahb_.first = *addr; } @@ -861,7 +863,7 @@ bool SimulatorRPCChannel::HandleAXI2BusRatio(string_list& rest) { bool SimulatorRPCChannel::HandleAXI2HighAddr(string_list& rest) { auto addr = detail::to_uint(detail::pop_front(rest)); - ICHECK(addr) << self_name_ << ": invalid value for AXI2 high adddress"; + TVM_FFI_ICHECK(addr) << self_name_ << ": invalid value for AXI2 high adddress"; if (addr) { axi2_.second = *addr; } @@ -870,7 +872,7 @@ bool SimulatorRPCChannel::HandleAXI2HighAddr(string_list& rest) { bool SimulatorRPCChannel::HandleAXI2LowAddr(string_list& rest) { auto addr = detail::to_uint(detail::pop_front(rest)); - ICHECK(addr) << self_name_ << ": invalid value for AXI2 low adddress"; + TVM_FFI_ICHECK(addr) << self_name_ << ": invalid value for AXI2 low adddress"; if (addr) { axi2_.first = *addr; } @@ -1120,8 +1122,8 @@ bool SimulatorRPCChannel::HandleQuiet(string_list& rest) { bool SimulatorRPCChannel::HandleReconnect(string_list& rest) { if (!debug_port_) { - LOG(FATAL) << "Reconnect error: --reconnect must be specified " - "AFTER --gdbserv "; + TVM_FFI_THROW(InternalError) << "Reconnect error: --reconnect must be specified " + "AFTER --gdbserv "; } CHECKED_CALL(ConfigureRemoteDebug, *debug_port_, true); return true; @@ -1374,7 +1376,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "tvm.contrib.hexagon.create_hexagon_session", [](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK(args.size() >= 4) << args.size() << " is less than 4"; + TVM_FFI_ICHECK(args.size() >= 4) << args.size() << " is less than 4"; auto session_name = args[0].cast(); int stack_size = args[1].cast(); diff --git a/src/runtime/logging.cc b/src/runtime/logging.cc index ef902f22fd5b..159e4b4a4cfa 100644 --- a/src/runtime/logging.cc +++ b/src/runtime/logging.cc @@ -93,28 +93,29 @@ TvmLogDebugSettings TvmLogDebugSettings::ParseSpec(const char* opt_spec) { break; } if (name.empty()) { - LOG(FATAL) << "TVM_LOG_DEBUG ill-formed at position " << tell_pos(name) << ": empty filename"; + TVM_FFI_THROW(InternalError) + << "TVM_LOG_DEBUG ill-formed at position " << tell_pos(name) << ": empty filename"; } name = FileToVLogMapKey(name); std::string level; if (!std::getline(spec_stream, level, ',')) { - LOG(FATAL) << "TVM_LOG_DEBUG ill-formed at position " << tell_pos(level) - << ": expecting \"=\" after \"" << name << "\""; + TVM_FFI_THROW(InternalError) << "TVM_LOG_DEBUG ill-formed at position " << tell_pos(level) + << ": expecting \"=\" after \"" << name << "\""; return settings; } if (level.empty()) { - LOG(FATAL) << "TVM_LOG_DEBUG ill-formed at position " << tell_pos(level) - << ": empty level after \"" << name << "\""; + TVM_FFI_THROW(InternalError) << "TVM_LOG_DEBUG ill-formed at position " << tell_pos(level) + << ": empty level after \"" << name << "\""; return settings; } // Parse level, default to 0 if ill-formed which we don't detect. char* end_of_level = nullptr; int level_val = static_cast(strtol(level.c_str(), &end_of_level, 10)); if (end_of_level != level.c_str() + level.size()) { - LOG(FATAL) << "TVM_LOG_DEBUG ill-formed at position " << tell_pos(level) - << ": invalid level: \"" << level << "\""; + TVM_FFI_THROW(InternalError) << "TVM_LOG_DEBUG ill-formed at position " << tell_pos(level) + << ": invalid level: \"" << level << "\""; return settings; } LOG(INFO) << "TVM_LOG_DEBUG enables VLOG statements in '" << name << "' up to level " << level; @@ -145,7 +146,7 @@ LogFatal::Entry& LogFatal::GetEntry() { std::string VLogContext::str() const { std::stringstream result; for (const auto* entry : context_stack_) { - ICHECK_NOTNULL(entry); + TVM_FFI_ICHECK_NOTNULL(entry); result << entry->str(); result << ": "; } diff --git a/src/runtime/memory/memory_manager.cc b/src/runtime/memory/memory_manager.cc index db4d33be3789..15572551debe 100644 --- a/src/runtime/memory/memory_manager.cc +++ b/src/runtime/memory/memory_manager.cc @@ -43,15 +43,15 @@ Storage::Storage(Buffer buffer, Allocator* allocator) { } inline void VerifyDataType(DLDataType dtype) { - ICHECK_GE(dtype.lanes, 1); + TVM_FFI_ICHECK_GE(dtype.lanes, 1); if (dtype.code == kDLFloat) { - ICHECK_EQ(dtype.bits % 8, 0); + TVM_FFI_ICHECK_EQ(dtype.bits % 8, 0); } else { // allow uint1 as a special flag for bool. if (dtype.bits == 1 && dtype.code == kDLUInt) return; - ICHECK_EQ(dtype.bits % 8, 0); + TVM_FFI_ICHECK_EQ(dtype.bits % 8, 0); } - ICHECK_EQ(dtype.bits & (dtype.bits - 1), 0); + TVM_FFI_ICHECK_EQ(dtype.bits & (dtype.bits - 1), 0); } inline size_t GetDataAlignment(const DLDataType& dtype) { @@ -83,7 +83,7 @@ Tensor StorageObj::AllocTensorScoped(int64_t offset, ffi::Shape shape, DLDataTyp }; size_t needed_size = ffi::GetDataSize(shape.Product(), dtype); - ICHECK(offset + needed_size <= this->buffer.size) + TVM_FFI_ICHECK(offset + needed_size <= this->buffer.size) << "storage allocation failure, attempted to allocate " << needed_size << " at offset " << offset << " in region that is " << this->buffer.size << "bytes"; @@ -95,7 +95,7 @@ Tensor StorageObj::AllocTensor(int64_t offset, ffi::Shape shape, DLDataType dtyp VerifyDataType(dtype); size_t needed_size = ffi::GetDataSize(shape.Product(), dtype); - ICHECK(offset + needed_size <= this->buffer.size) + TVM_FFI_ICHECK(offset + needed_size <= this->buffer.size) << "storage allocation failure, attempted to allocate " << needed_size << " at offset " << offset << " in region that is " << this->buffer.size << "bytes"; class StorageAlloc { @@ -166,7 +166,7 @@ Allocator* GetDeviceSpecificAllocator(Device dev, AllocatorType type) { break; } default: - LOG(FATAL) << "Unknown allocator type: " << type; + TVM_FFI_THROW(InternalError) << "Unknown allocator type: " << type; } } return allocator; @@ -195,10 +195,11 @@ Allocator* MemoryManager::GetAllocator(Device dev, AllocatorType type) { std::lock_guard lock(m->mu_); auto it = m->allocators_.find(dev); if (it == m->allocators_.end()) { - LOG(FATAL) << "Allocator for " << dev << " has not been created yet."; + TVM_FFI_THROW(InternalError) << "Allocator for " << dev << " has not been created yet."; } if (it->second.find(type) == it->second.end()) { - LOG(FATAL) << "Allocator for " << dev << " of type " << type << " has not been created yet."; + TVM_FFI_THROW(InternalError) << "Allocator for " << dev << " of type " << type + << " has not been created yet."; } return it->second.at(type).get(); } @@ -254,8 +255,8 @@ Buffer Allocator::Alloc(Device dev, ffi::Shape shape, DLDataType type_hint, size_t size = ffi::GetDataSize(shape.Product(), type_hint); return Alloc(dev, size, alignment, type_hint); } - LOG(FATAL) << "Allocator cannot allocate data space with " - << "specified memory scope: " << mem_scope; + TVM_FFI_THROW(InternalError) << "Allocator cannot allocate data space with " + << "specified memory scope: " << mem_scope; return {}; } diff --git a/src/runtime/memory/pooled_allocator.h b/src/runtime/memory/pooled_allocator.h index 744c61987cdd..13c16f7ace64 100644 --- a/src/runtime/memory/pooled_allocator.h +++ b/src/runtime/memory/pooled_allocator.h @@ -78,7 +78,7 @@ class PooledAllocator : public Allocator { if (AllowMemoryScope(mem_scope)) { return Allocator::Alloc(dev, shape, type_hint, mem_scope); } - LOG(FATAL) << "This alloc should be implemented"; + TVM_FFI_THROW(InternalError) << "This alloc should be implemented"; return {}; } diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index f10489826a5a..8d72fac97a8a 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -116,7 +116,7 @@ class Stream { } [cb addCompletedHandler:^(id buffer) { if (buffer.status == MTLCommandBufferStatusError) { - ICHECK(buffer.error != nil); + TVM_FFI_ICHECK(buffer.error != nil); this->SetError(buffer.error.localizedDescription.UTF8String); } }]; @@ -155,8 +155,8 @@ class MetalWorkspace final : public DeviceAPI { ~MetalWorkspace(); // Get device for given device id GetDevice(Device dev) { - ICHECK_EQ(dev.device_type, kDLMetal); - ICHECK(dev.device_id >= 0 && static_cast(dev.device_id) < devices.size()) + TVM_FFI_ICHECK_EQ(dev.device_type, kDLMetal); + TVM_FFI_ICHECK(dev.device_id >= 0 && static_cast(dev.device_id) < devices.size()) << "Invalid Metal device_id=" << dev.device_id; return devices[dev.device_id]; } diff --git a/src/runtime/module.cc b/src/runtime/module.cc index c782cb96c09f..fde00b794487 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -63,7 +63,7 @@ bool RuntimeEnabled(const ffi::String& target_str) { if (!pf.has_value()) return false; return (*pf)(target).cast(); } else { - LOG(FATAL) << "Unknown optional runtime " << target; + TVM_FFI_THROW(InternalError) << "Unknown optional runtime " << target; } return tvm::ffi::Function::GetGlobal(f_name).has_value(); } diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index a7ed2a2824ec..a9fb5c01ec24 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -206,15 +206,18 @@ inline cl_channel_type DTypeToOpenCLChannelType(DLDataType data_type) { } else if (dtype == DataType::UInt(32)) { return CL_UNSIGNED_INT32; } - LOG(FATAL) << "data type is not supported in OpenCL runtime yet: " << dtype; + TVM_FFI_THROW(InternalError) << "data type is not supported in OpenCL runtime yet: " << dtype; } /*! * \brief Protected OpenCL call * \param func Expression to call. */ -#define OPENCL_CHECK_ERROR(e) \ - { ICHECK(e == CL_SUCCESS) << "OpenCL Error, code=" << e << ": " << cl::CLGetErrorString(e); } +#define OPENCL_CHECK_ERROR(e) \ + { \ + TVM_FFI_ICHECK(e == CL_SUCCESS) \ + << "OpenCL Error, code=" << e << ": " << cl::CLGetErrorString(e); \ + } #define OPENCL_CALL(func) \ { \ @@ -281,17 +284,17 @@ class OpenCLWorkspace : public DeviceAPI { virtual bool IsOpenCLDevice(Device dev) { return dev.device_type == kDLOpenCL; } // get the queue of the device cl_command_queue GetQueue(Device dev) { - ICHECK(IsOpenCLDevice(dev)); + TVM_FFI_ICHECK(IsOpenCLDevice(dev)); this->Init(); - ICHECK(dev.device_id >= 0 && static_cast(dev.device_id) < queues.size()) + TVM_FFI_ICHECK(dev.device_id >= 0 && static_cast(dev.device_id) < queues.size()) << "Invalid OpenCL device_id=" << dev.device_id << ". " << GetError(); return queues[dev.device_id]; } // get the event queue of the context std::vector& GetEventQueue(Device dev) { - ICHECK(IsOpenCLDevice(dev)); + TVM_FFI_ICHECK(IsOpenCLDevice(dev)); this->Init(); - ICHECK(dev.device_id >= 0 && static_cast(dev.device_id) < queues.size()) + TVM_FFI_ICHECK(dev.device_id >= 0 && static_cast(dev.device_id) < queues.size()) << "Invalid OpenCL device_id=" << dev.device_id << ". " << GetError(); return events[dev.device_id]; } diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 700c4742fdc0..8820585bb612 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -59,7 +59,7 @@ struct ImageInfo { */ ImageInfo GetImageInfo(const cl::BufferDescriptor* desc, const DLTensor* tensor) { ImageInfo info{}; - ICHECK(tensor->dtype.lanes == 1) << "Image dtype has lanes: " << tensor->dtype.lanes; + TVM_FFI_ICHECK(tensor->dtype.lanes == 1) << "Image dtype has lanes: " << tensor->dtype.lanes; info.origin[0] = info.origin[1] = info.origin[2] = 0; info.row_pitch = 0; @@ -85,7 +85,8 @@ cl::BufferDescriptor::MemoryLayout cl::BufferDescriptor::MemoryLayoutFromScope( } else if (mem_scope.value() == "global.texture-nhwc") { return cl::BufferDescriptor::MemoryLayout::kImage2DNHWC; } - LOG(FATAL) << "No memory layout defined for memory of scope: " << mem_scope.value(); + TVM_FFI_THROW(InternalError) << "No memory layout defined for memory of scope: " + << mem_scope.value(); } ffi::String cl::BufferDescriptor::ScopeFromMemoryLayout(cl::BufferDescriptor::MemoryLayout layout) { @@ -99,8 +100,8 @@ ffi::String cl::BufferDescriptor::ScopeFromMemoryLayout(cl::BufferDescriptor::Me case cl::BufferDescriptor::MemoryLayout::kImage2DNHWC: return "global.texture-nhwc"; } - LOG(FATAL) << "No scope corresponding to the provided memory layout: " - << static_cast(layout); + TVM_FFI_THROW(InternalError) << "No scope corresponding to the provided memory layout: " + << static_cast(layout); return ""; } @@ -126,7 +127,8 @@ OpenCLWorkspace* OpenCLWorkspace::Global() { cl_device_id OpenCLWorkspace::GetCLDeviceID(int device_id) { this->Init(); - ICHECK_LT(device_id, devices.size()) << "Invalid device id " << device_id << ". " << GetError(); + TVM_FFI_ICHECK_LT(device_id, devices.size()) + << "Invalid device id " << device_id << ". " << GetError(); return devices[device_id]; } @@ -322,7 +324,7 @@ void* OpenCLWorkspace::AllocCLImage(Device dev, void* back_buffer, size_t width, size_t depth, size_t row_pitch, DLDataType type_hint, ffi::Optional mem_scope) { this->Init(); - ICHECK(std::string(mem_scope.value()).find("texture") != std::string::npos) + TVM_FFI_ICHECK(std::string(mem_scope.value()).find("texture") != std::string::npos) << "Expect texture scope while creating an Image object"; cl::BufferDescriptor* back_desc = static_cast(back_buffer); cl_device_id device_id = GetCLDeviceID(dev.device_id); @@ -465,7 +467,7 @@ void OpenCLWorkspace::SetNativePtr(const tvm::runtime::Tensor& narr, void* host_ OPENCL_CHECK_ERROR(err_code); #endif } else { - LOG(FATAL) << "Native Ptr not enabled over image objects"; + TVM_FFI_THROW(InternalError) << "Native Ptr not enabled over image objects"; } } @@ -506,8 +508,8 @@ void OpenCLWorkspace::FreeDataSpace(Device dev, void* ptr) { void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { this->Init(); size_t nbytes = GetDataSize(*from); - ICHECK_EQ(nbytes, GetDataSize(*to)); - ICHECK(IsContiguous(*from) && IsContiguous(*to)) + TVM_FFI_ICHECK_EQ(nbytes, GetDataSize(*to)); + TVM_FFI_ICHECK(IsContiguous(*from) && IsContiguous(*to)) << "CopyDataFromTo only support contiguous array for now"; if (IsOpenCLDevice(from->device) && IsOpenCLDevice(to->device)) { @@ -579,13 +581,13 @@ void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHand } OPENCL_CALL(clFinish(this->GetQueue(to->device))); } else { - LOG(FATAL) << "Expect copy from/to OpenCL or between OpenCL"; + TVM_FFI_THROW(InternalError) << "Expect copy from/to OpenCL or between OpenCL"; } } void OpenCLWorkspace::StreamSync(Device dev, TVMStreamHandle stream) { this->Init(); - ICHECK(stream == nullptr); + TVM_FFI_ICHECK(stream == nullptr); OPENCL_CALL(clFinish(this->GetQueue(dev))); } @@ -730,7 +732,7 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic initialized_ = true; return; } - ICHECK_EQ(this->queues.size(), 0U); + TVM_FFI_ICHECK_EQ(this->queues.size(), 0U); cl_int err_code; for (auto& [platform, devices] : device_map) { this->platform_ids.push_back(platform); @@ -768,9 +770,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { int32_t device_type = args[0].cast(); int32_t device_id = args[1].cast(); auto scope = args[4].cast(); - CHECK(scope.find("texture") != std::string::npos); + TVM_FFI_ICHECK(scope.find("texture") != std::string::npos); int64_t ndim = args[5].cast(); - CHECK_EQ(ndim, 3); + TVM_FFI_ICHECK_EQ(ndim, 3); int64_t* shape = static_cast(args[6].cast()); int64_t width = shape[0]; int64_t height = shape[1]; @@ -794,7 +796,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { int32_t device_type = args[0].cast(); int32_t device_id = args[1].cast(); auto scope = args[2].cast(); - CHECK(scope.find("texture") != std::string::npos); + TVM_FFI_ICHECK(scope.find("texture") != std::string::npos); void* data = args[3].cast(); OpenCLWorkspace* ptr = OpenCLWorkspace::Global(); Device dev; @@ -869,7 +871,7 @@ class OpenCLPooledAllocator final : public memory::PooledAllocator { DLOG(INFO) << "allocate " << size << " B, used memory " << used_memory_ << " B"; return buf; } - LOG(FATAL) << "Unsupported memory scope for this Allocator:" << mem_scope; + TVM_FFI_THROW(InternalError) << "Unsupported memory scope for this Allocator:" << mem_scope; return {}; } diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index b488116d7464..c7f873a02180 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -53,7 +53,7 @@ class OpenCLWrappedFunc { // invoke the function with void arguments void operator()(ffi::PackedArgs args, ffi::Any* rv, void** void_args) const { - ICHECK(w_->devices.size() > 0) << "No OpenCL device"; + TVM_FFI_ICHECK(w_->devices.size() > 0) << "No OpenCL device"; cl::OpenCLThreadEntry* t = w_->GetThreadEntry(); // get the kernel from thread local kernel table. if (entry_.kernel_id >= t->kernel_table.size()) { @@ -137,7 +137,7 @@ cl::OpenCLWorkspace* OpenCLModuleNodeBase::GetGlobalWorkspace() { ffi::Optional OpenCLModuleNodeBase::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); - ICHECK_EQ(sptr_to_self.get(), this); + TVM_FFI_ICHECK_EQ(sptr_to_self.get(), this); auto opt_info = fmap_.Get(name); if (!opt_info.has_value()) return std::nullopt; FunctionInfo info = opt_info.value(); @@ -145,13 +145,13 @@ ffi::Optional OpenCLModuleNodeBase::GetFunction(const ffi::String std::vector arg_size(info->arg_types.size()); for (size_t i = 0; i < info->arg_types.size(); ++i) { DLDataType t = info->arg_types[i]; - ICHECK_EQ(t.lanes, 1U); + TVM_FFI_ICHECK_EQ(t.lanes, 1U); if (t.code == kDLOpaqueHandle) { // specially store pointer type size in OpenCL driver arg_size[i] = sizeof(void*); } else { uint32_t bits = t.bits; - ICHECK_EQ(bits % 8, 0U); + TVM_FFI_ICHECK_EQ(bits % 8, 0U); arg_size[i] = bits / 8; } } @@ -162,7 +162,7 @@ ffi::Optional OpenCLModuleNodeBase::GetFunction(const ffi::String void OpenCLModuleNode::WriteToFile(const ffi::String& file_name, const ffi::String& format) const { std::string fmt = GetFileFormat(file_name, format); - ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; + TVM_FFI_ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, data_); @@ -205,10 +205,10 @@ void OpenCLModuleNode::Init() { // split into source artifacts for each kernel parsed_kernels_ = SplitKernels(InspectSource("cl")); - ICHECK(!parsed_kernels_.empty()) << "The OpenCL module expects a kernel delimited " - << "source from code generation, but no kernel " - << "delimiter was found."; - ICHECK_EQ(fmap_.size(), parsed_kernels_.size()) + TVM_FFI_ICHECK(!parsed_kernels_.empty()) << "The OpenCL module expects a kernel delimited " + << "source from code generation, but no kernel " + << "delimiter was found."; + TVM_FFI_ICHECK_EQ(fmap_.size(), parsed_kernels_.size()) << "The number of parsed kernel sources does not match the number of kernel functions"; } @@ -216,7 +216,7 @@ bool OpenCLModuleNode::IsProgramCreated(const std::string& func_name, int device auto size = programs_[func_name].size(); if (size > 0 && programs_[func_name][device_id] != nullptr) return true; auto dev_size = GetGlobalWorkspace()->devices.size(); - ICHECK(device_id < static_cast(dev_size)) + TVM_FFI_ICHECK(device_id < static_cast(dev_size)) << "Device id " << device_id << " is bigger than number of available devices"; // zero initialize cl_program pointers for each device kernel if (size == 0) programs_[func_name].resize(dev_size, nullptr); @@ -247,7 +247,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre clCreateProgramWithBinary(w->contexts[platform], 1, &dev, &len, &s, nullptr, &err); OPENCL_CHECK_ERROR(err); } else { - LOG(FATAL) << "Unknown OpenCL format " << fmt_; + TVM_FFI_THROW(InternalError) << "Unknown OpenCL format " << fmt_; } // build program cl_int err; @@ -261,9 +261,9 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre log.resize(len); clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr); - LOG(FATAL) << "OpenCL build error for device=" << dev - << "\nError: " << cl::CLGetErrorString(err) << "\n" - << log; + TVM_FFI_THROW(InternalError) << "OpenCL build error for device=" << dev + << "\nError: " << cl::CLGetErrorString(err) << "\n" + << log; } } // build kernel @@ -311,7 +311,7 @@ void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) { log.resize(len); clGetProgramBuildInfo(programs_[name][device_id], dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr); - LOG(FATAL) << "OpenCL build error for device=" << dev << "\n" << log; + TVM_FFI_THROW(InternalError) << "OpenCL build error for device=" << dev << "\n" << log; } } } @@ -333,7 +333,7 @@ std::string OpenCLModuleNode::GetPreCompiledPrograms() { size_t size; clGetProgramInfo(programs_[name][device_id], CL_PROGRAM_BINARY_SIZES, sizeof(size_t), &size, nullptr); - ICHECK(size > 0) << "Size of binary is 0"; + TVM_FFI_ICHECK(size > 0) << "Size of binary is 0"; std::vector bin_vector(size); unsigned char* binary = bin_vector.data(); clGetProgramInfo(programs_[name][device_id], CL_PROGRAM_BINARIES, sizeof(unsigned char*), @@ -347,7 +347,7 @@ std::string OpenCLModuleNode::GetPreCompiledPrograms() { ffi::Optional OpenCLModuleNode::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); - ICHECK_EQ(sptr_to_self.get(), this); + TVM_FFI_ICHECK_EQ(sptr_to_self.get(), this); if (name == "opencl.GetPreCompiledPrograms") { return ffi::Function([sptr_to_self, this](ffi::PackedArgs args, ffi::Any* rv) { *rv = this->GetPreCompiledPrograms(); @@ -384,7 +384,7 @@ ffi::Module OpenCLModuleLoadFromBytes(const ffi::Bytes& bytes) { ffi::Map fmap; std::string fmt; stream.Read(&fmt); - ICHECK(stream.Read(&fmap)); + TVM_FFI_ICHECK(stream.Read(&fmap)); stream.Read(&data); return OpenCLModuleCreate(data, fmt, fmap, std::string()); } diff --git a/src/runtime/opencl/opencl_module_spirv.cc b/src/runtime/opencl/opencl_module_spirv.cc index 6893ebdbaf1d..0125c4121ace 100644 --- a/src/runtime/opencl/opencl_module_spirv.cc +++ b/src/runtime/opencl/opencl_module_spirv.cc @@ -55,7 +55,7 @@ class OpenCLSPIRVModuleNode : public OpenCLModuleNodeBase { void OpenCLSPIRVModuleNode::WriteToFile(const ffi::String& file_name, const ffi::String& format) const { // TODO(masahi): How SPIRV binaries should be save to a file? - LOG(FATAL) << "Not implemented."; + TVM_FFI_THROW(InternalError) << "Not implemented."; } ffi::Bytes OpenCLSPIRVModuleNode::SaveToBytes() const { @@ -116,7 +116,7 @@ cl_kernel OpenCLSPIRVModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenC log.resize(len); clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr); - LOG(FATAL) << "OpenCL build error for device=" << dev << "\n" << log; + TVM_FFI_THROW(InternalError) << "OpenCL build error for device=" << dev << "\n" << log; } } // build kernel diff --git a/src/runtime/opencl/opencl_wrapper/opencl_wrapper.cc b/src/runtime/opencl/opencl_wrapper/opencl_wrapper.cc index 948d1800b77f..e4c4a1a9af31 100644 --- a/src/runtime/opencl/opencl_wrapper/opencl_wrapper.cc +++ b/src/runtime/opencl/opencl_wrapper/opencl_wrapper.cc @@ -124,7 +124,7 @@ class LibOpenCLWrapper { #endif if (m_libHandler != nullptr) return; } - ICHECK(m_libHandler != nullptr) << "Error! Cannot open libOpenCL!"; + TVM_FFI_ICHECK(m_libHandler != nullptr) << "Error! Cannot open libOpenCL!"; } private: diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h index 4d938c1c1dc8..0cbb631b1669 100644 --- a/src/runtime/pack_args.h +++ b/src/runtime/pack_args.h @@ -139,7 +139,7 @@ enum ArgConvertCode { }; inline ArgConvertCode GetArgConvertCode(DLDataType t) { - ICHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to device function for now"; + TVM_FFI_ICHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to device function for now"; if (t.code == kDLInt) { if (t.bits == 64U) return INT64_TO_INT64; if (t.bits == 32U) return INT64_TO_INT32; @@ -151,7 +151,7 @@ inline ArgConvertCode GetArgConvertCode(DLDataType t) { } else if (t.code == kDLOpaqueHandle) { return HANDLE_TO_HANDLE; } - LOG(FATAL) << "Cannot handle " << t << " as device function argument"; + TVM_FFI_THROW(InternalError) << "Cannot handle " << t << " as device function argument"; TVM_FFI_UNREACHABLE(); } @@ -234,7 +234,7 @@ inline ffi::Function PackFuncNonBufferArg_(F f, int base, } case HANDLE_TO_HANDLE: case HANDLE_TO_TENSORMAP: { - LOG(FATAL) << "not reached"; + TVM_FFI_THROW(InternalError) << "not reached"; break; } } @@ -297,7 +297,7 @@ inline ffi::Function PackFuncPackedArgAligned_(F f, const std::vector& arg_types) { } } for (size_t i = base; i < arg_types.size(); ++i) { - ICHECK(arg_types[i].code != kDLOpaqueHandle) << "Device function need to be organized"; + TVM_FFI_ICHECK(arg_types[i].code != kDLOpaqueHandle) << "Device function need to be organized"; } return base; } diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index 91e2a3f7eb76..99b5d77a6e9e 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -194,7 +194,7 @@ std::vector ToShape(Tensor shape_tensor) { // Otherwise we should be rank-1, and we will extract the number of dimensions // for the output vector. - ICHECK_EQ(rank, 1U) << "shape tensor should be a k-length vector, found " << rank; + TVM_FFI_ICHECK_EQ(rank, 1U) << "shape tensor should be a k-length vector, found " << rank; int64_t ndim = shape_tensor.Shape().at(0); shape.resize(ndim); @@ -206,7 +206,7 @@ std::vector ToShape(Tensor shape_tensor) { int64_t* dims = reinterpret_cast(dl_tensor->data); shape.assign(dims, dims + ndim); } else { - LOG(FATAL) << "invalid shape tensor datatype: " << dtype; + TVM_FFI_THROW(InternalError) << "invalid shape tensor datatype: " << dtype; } return shape; @@ -318,7 +318,7 @@ void metric_as_json(std::ostream& os, ffi::Any o) { os << "{\"ratio\":" << std::setprecision(std::numeric_limits::max_digits10) << std::fixed << n->ratio << "}"; } else { - LOG(FATAL) << "Unprintable type " << o.GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unprintable type " << o.GetTypeKey(); } } } // namespace @@ -390,7 +390,7 @@ ffi::String ReportNode::AsJSON() const { // and Percent; average for Ratio; and assumes all Strings are the same. All // ObjectRefs in metrics must have the same type. Any AggregateMetric(const std::vector& metrics) { - ICHECK_GT(metrics.size(), 0) << "Must pass a non-zero number of metrics"; + TVM_FFI_ICHECK_GT(metrics.size(), 0) << "Must pass a non-zero number of metrics"; if (metrics[0].as()) { double sum = 0; for (auto& metric : metrics) { @@ -424,9 +424,10 @@ Any AggregateMetric(const std::vector& metrics) { // Assume all strings in metrics are the same. return metrics[0]; } else { - LOG(FATAL) << "Can only aggregate metrics with types DurationNode, CountNode, " - "PercentNode, RatioNode, and String, but got " - << metrics[0].GetTypeKey(); + TVM_FFI_THROW(InternalError) + << "Can only aggregate metrics with types DurationNode, CountNode, " + "PercentNode, RatioNode, and String, but got " + << metrics[0].GetTypeKey(); return ffi::Any(); // To silence warnings } } @@ -467,7 +468,7 @@ static ffi::String print_metric(ffi::Any metric) { } else if (auto opt_str = metric.as()) { val = *opt_str; } else { - LOG(FATAL) << "Cannot print metric of type " << metric.GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Cannot print metric of type " << metric.GetTypeKey(); } return val; } @@ -641,8 +642,7 @@ ffi::String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums if (row < cols[col].size()) { s << std::setw(widths[col]) << cols[col][row] << " "; } else { - s << std::setw(widths[col]) << "" - << " "; + s << std::setw(widths[col]) << " "; } } s << std::endl; @@ -734,8 +734,8 @@ ffi::Map parse_metrics(const json::Object& obj) { } else if (metric_value_name == "string") { o = ffi::String(type_val.cast()); } else { - LOG(FATAL) << "Cannot parse metric of type " << metric_value_name - << " valid types are microseconds, percent, count."; + TVM_FFI_THROW(InternalError) << "Cannot parse metric of type " << metric_value_name + << " valid types are microseconds, percent, count."; } } metrics.Set(metric_name, o); @@ -786,48 +786,49 @@ ffi::Function ProfileFunction(ffi::Module mod, std::string func_name, int device int device_id, int warmup_iters, ffi::Array collectors) { // Module::GetFunction is not const, so this lambda has to be mutable - return ffi::Function::FromPacked([=](const ffi::AnyView* args, int32_t num_args, - ffi::Any* ret) mutable { - auto optf = mod->GetFunction(func_name); - CHECK(optf.has_value()) << "There is no function called \"" << func_name << "\" in the module"; - auto f = *optf; - Device dev{static_cast(device_type), device_id}; - - // warmup - for (int i = 0; i < warmup_iters; i++) { - f.CallPacked(args, num_args, ret); - } - - for (auto& collector : collectors) { - collector->Init({DeviceWrapper(dev)}); - } - std::vector> results; - results.reserve(collectors.size()); - std::vector> collector_data; - collector_data.reserve(collectors.size()); - for (auto& collector : collectors) { - ObjectRef o = collector->Start(dev); - // If not defined, then the collector cannot time this device. - if (o.defined()) { - collector_data.push_back({collector, o}); - } - } + return ffi::Function::FromPacked( + [=](const ffi::AnyView* args, int32_t num_args, ffi::Any* ret) mutable { + auto optf = mod->GetFunction(func_name); + TVM_FFI_ICHECK(optf.has_value()) + << "There is no function called \"" << func_name << "\" in the module"; + auto f = *optf; + Device dev{static_cast(device_type), device_id}; + + // warmup + for (int i = 0; i < warmup_iters; i++) { + f.CallPacked(args, num_args, ret); + } - // TODO(tkonolige): repeated calls if the runtime is small? - f.CallPacked(args, num_args, ret); + for (auto& collector : collectors) { + collector->Init({DeviceWrapper(dev)}); + } + std::vector> results; + results.reserve(collectors.size()); + std::vector> collector_data; + collector_data.reserve(collectors.size()); + for (auto& collector : collectors) { + ObjectRef o = collector->Start(dev); + // If not defined, then the collector cannot time this device. + if (o.defined()) { + collector_data.push_back({collector, o}); + } + } - for (auto& kv : collector_data) { - results.push_back(kv.first->Stop(kv.second)); - } - ffi::Map combined_results; - for (auto m : results) { - for (auto p : m) { - // assume that there is no shared metric name between collectors - combined_results.Set(p.first, p.second); - } - } - *ret = combined_results; - }); + // TODO(tkonolige): repeated calls if the runtime is small? + f.CallPacked(args, num_args, ret); + + for (auto& kv : collector_data) { + results.push_back(kv.first->Stop(kv.second)); + } + ffi::Map combined_results; + for (auto m : results) { + for (auto p : m) { + // assume that there is no shared metric name between collectors + combined_results.Set(p.first, p.second); + } + } + *ret = combined_results; + }); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -837,7 +838,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](ffi::Module mod, ffi::String func_name, int device_type, int device_id, int warmup_iters, ffi::Array collectors) { if (mod->kind() == std::string("rpc")) { - LOG(FATAL) + TVM_FFI_THROW(InternalError) << "Profiling a module over RPC is not yet supported"; // because we can't send // MetricCollectors over rpc. throw; @@ -851,7 +852,7 @@ ffi::Function WrapTimeEvaluator(ffi::Function pf, Device dev, int number, int re int min_repeat_ms, int limit_zero_time_iterations, int cooldown_interval_ms, int repeats_to_cooldown, int cache_flush_bytes, ffi::Function f_preproc) { - ICHECK(pf != nullptr); + TVM_FFI_ICHECK(pf != nullptr); auto ftimer = [pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown, cache_flush_bytes, diff --git a/src/runtime/rocm/rocm_common.h b/src/runtime/rocm/rocm_common.h index ec3e744d3034..3c07561c38d1 100644 --- a/src/runtime/rocm/rocm_common.h +++ b/src/runtime/rocm/rocm_common.h @@ -35,18 +35,19 @@ namespace tvm { namespace runtime { -#define ROCM_DRIVER_CALL(x) \ - { \ - hipError_t result = x; \ - if (result != hipSuccess && result != hipErrorDeinitialized) { \ - LOG(FATAL) << "ROCM HIP Error: " #x " failed with error: " << hipGetErrorString(result); \ - } \ +#define ROCM_DRIVER_CALL(x) \ + { \ + hipError_t result = x; \ + if (result != hipSuccess && result != hipErrorDeinitialized) { \ + TVM_FFI_THROW(InternalError) \ + << "ROCM HIP Error: " #x " failed with error: " << hipGetErrorString(result); \ + } \ } -#define ROCM_CALL(func) \ - { \ - hipError_t e = (func); \ - ICHECK(e == hipSuccess) << "ROCM HIP: " << hipGetErrorString(e); \ +#define ROCM_CALL(func) \ + { \ + hipError_t e = (func); \ + TVM_FFI_ICHECK(e == hipSuccess) << "ROCM HIP: " << hipGetErrorString(e); \ } /*! \brief Thread local workspace */ diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 6408892bb5fa..ed7bd98ffe5d 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -147,7 +147,7 @@ class ROCMDeviceAPI final : public DeviceAPI { *rv = value; } void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final { - ICHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes"; + TVM_FFI_ICHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes"; void* ret; if (dev.device_type == kDLROCMHost) { VLOG(1) << "allocating " << nbytes << "bytes on host"; @@ -205,7 +205,7 @@ class ROCMDeviceAPI final : public DeviceAPI { ROCM_CALL(hipSetDevice(dev_to.device_id)); GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream); } else { - LOG(FATAL) << "expect copy from/to GPU or between GPU"; + TVM_FFI_THROW(InternalError) << "expect copy from/to GPU or between GPU"; } } diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index d1e6f874434b..56f929c3c284 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -75,7 +75,7 @@ class ROCMModuleNode : public ffi::ModuleObj { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); // note: llvm and asm formats are not laodable, so we don't save them - ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; + TVM_FFI_ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, data_); } @@ -113,8 +113,8 @@ class ROCMModuleNode : public ffi::ModuleObj { hipFunction_t func; hipError_t result = hipModuleGetFunction(&func, module_[device_id], func_name.c_str()); if (result != hipSuccess) { - LOG(FATAL) << "ROCMError: hipModuleGetFunction " << func_name - << " failed with error: " << hipGetErrorString(result); + TVM_FFI_THROW(ROCMError) << "hipModuleGetFunction " << func_name + << " failed with error: " << hipGetErrorString(result); } return func; } @@ -129,7 +129,7 @@ class ROCMModuleNode : public ffi::ModuleObj { size_t nbytes = 0; ROCM_DRIVER_CALL(hipModuleGetGlobal(&global, &nbytes, module_[device_id], global_name.c_str())); - ICHECK_EQ(nbytes, expect_nbytes); + TVM_FFI_ICHECK_EQ(nbytes, expect_nbytes); return global; } @@ -199,7 +199,7 @@ class ROCMWrappedFunc { ffi::Optional ROCMModuleNode::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); - ICHECK_EQ(sptr_to_self.get(), this); + TVM_FFI_ICHECK_EQ(sptr_to_self.get(), this); auto opt_info = fmap_.Get(name); if (!opt_info.has_value()) return std::nullopt; FunctionInfo info = opt_info.value(); @@ -231,7 +231,7 @@ ffi::Module ROCMModuleLoadFromBytes(const ffi::Bytes& bytes) { ffi::Map fmap; std::string fmt; stream.Read(&fmt); - ICHECK(stream.Read(&fmap)); + TVM_FFI_ICHECK(stream.Read(&fmap)); stream.Read(&data); return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string()); } diff --git a/src/runtime/rpc/rpc_channel.cc b/src/runtime/rpc/rpc_channel.cc index 50f7195a2224..f462dac3d257 100644 --- a/src/runtime/rpc/rpc_channel.cc +++ b/src/runtime/rpc/rpc_channel.cc @@ -35,7 +35,7 @@ size_t CallbackChannel::Send(const void* data, size_t size) { bytes.size = size; int64_t n = fsend_(&bytes).cast(); if (n == -1) { - LOG(FATAL) << "CallbackChannel::Send"; + TVM_FFI_THROW(InternalError) << "CallbackChannel::Send"; } return static_cast(n); } @@ -44,7 +44,7 @@ size_t CallbackChannel::Recv(void* data, size_t size) { Any ret = frecv_(size); auto opt_bytes = ret.try_cast(); - CHECK(opt_bytes.has_value()) << "CallbackChannel::Recv"; + TVM_FFI_ICHECK(opt_bytes.has_value()) << "CallbackChannel::Recv"; ffi::Bytes bytes = std::move(opt_bytes.value()); memcpy(static_cast(data), bytes.data(), bytes.size()); diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index 88e01255d82a..dd4b1993141e 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -82,7 +82,7 @@ class RPCDeviceAPI final : public DeviceAPI { DLDevice dev_from = from->device; DLDevice dev_to = to->device; if (IsRPCSessionDevice(dev_from) && IsRPCSessionDevice(dev_to)) { - ICHECK(dev_from.device_type == dev_to.device_type) + TVM_FFI_ICHECK(dev_from.device_type == dev_to.device_type) << "Cannot copy across two different remote session"; DLTensor from_tensor = *from; from_tensor.device = RemoveRPCSessionMask(dev_from); @@ -108,7 +108,7 @@ class RPCDeviceAPI final : public DeviceAPI { size_t nbytes = GetDataSize(*from); GetSess(dev_to)->CopyToRemote(from_bytes, &to_tensor, nbytes); } else { - LOG(FATAL) << "expect copy from/to remote or between remote"; + TVM_FFI_THROW(InternalError) << "expect copy from/to remote or between remote"; } } @@ -141,7 +141,7 @@ class RPCDeviceAPI final : public DeviceAPI { void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t num_bytes, Device dev_from, Device dev_to, DLDataType type_hint, TVMStreamHandle stream) final { - LOG(FATAL) << "Not implemented."; + TVM_FFI_THROW(InternalError) << "Not implemented."; } private: diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index 92bb188c4ff5..24687e504c48 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -121,7 +121,7 @@ class RPCEndpoint::EventHandler : public support::Stream { break; case kRecvPacketNumBytes: { uint64_t packet_nbytes; - ICHECK(this->Read(&packet_nbytes)); + TVM_FFI_ICHECK(this->Read(&packet_nbytes)); if (packet_nbytes != 0) { this->SwitchToState(kProcessPacket); this->RequestBytes(packet_nbytes); @@ -179,18 +179,20 @@ class RPCEndpoint::EventHandler : public support::Stream { continue; if (const Object* obj = args[i].as()) { if (!obj->IsInstance()) { - LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type " << obj->GetTypeKey() - << " (type_index = " << obj->type_index() << ")"; + TVM_FFI_THROW(ValueError) + << "Cannot pass argument " << i << ", type " << obj->GetTypeKey() + << " (type_index = " << obj->type_index() << ")"; } } else if (auto opt_device = args[i].as()) { DLDevice dev = opt_device.value(); - ICHECK(!IsRPCSessionDevice(dev)) << "InternalError: cannot pass RPC device in the channel"; + TVM_FFI_CHECK(!IsRPCSessionDevice(dev), InternalError) + << "cannot pass RPC device in the channel"; } } } void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { - LOG(FATAL) << "RPCServerError:" << RPCServerStatusToString(code); + TVM_FFI_THROW(RPCServerError) << RPCServerStatusToString(code); } uint64_t PackedSeqGetNumBytes(const ffi::AnyView* packed_args, int num_args, bool client_mode) { @@ -240,9 +242,9 @@ class RPCEndpoint::EventHandler : public support::Stream { this->template Write((*opt_bytes).size()); this->template WriteArray((*opt_bytes).data(), (*opt_bytes).size()); } else { - LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: " - << any_view_ptr->GetTypeKey() << " (type_index = " << any_view_ptr->type_index() - << ")"; + TVM_FFI_THROW(ValueError) << "Object type is not supported in RPC calling convention: " + << any_view_ptr->GetTypeKey() + << " (type_index = " << any_view_ptr->type_index() << ")"; } } uint64_t GetFFIAnyProtocolBytes(const TVMFFIAny* in) { @@ -254,9 +256,9 @@ class RPCEndpoint::EventHandler : public support::Stream { } else if (auto opt_bytes = any_view_ptr->as()) { return sizeof(uint32_t) + sizeof(uint64_t) + (*opt_bytes).size(); } else { - LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: " - << any_view_ptr->GetTypeKey() << " (type_index = " << any_view_ptr->type_index() - << ")"; + TVM_FFI_THROW(ValueError) << "Object type is not supported in RPC calling convention: " + << any_view_ptr->GetTypeKey() + << " (type_index = " << any_view_ptr->type_index() << ")"; TVM_FFI_UNREACHABLE(); } } @@ -298,8 +300,9 @@ class RPCEndpoint::EventHandler : public support::Stream { *reinterpret_cast(out) = ret; any_arena_.emplace_back(ret); } else { - LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " - << Object::TypeIndex2Key(type_index) << " (type_index = " << type_index << ")"; + TVM_FFI_THROW(ValueError) << "Object type is not supported in Disco calling convention: " + << Object::TypeIndex2Key(type_index) + << " (type_index = " << type_index << ")"; } } @@ -346,7 +349,7 @@ class RPCEndpoint::EventHandler : public support::Stream { void SwitchToState(State state) { // invariant if (state != kCopyAckReceived) { - ICHECK_EQ(pending_request_bytes_, 0U) << "state=" << state; + TVM_FFI_ICHECK_EQ(pending_request_bytes_, 0U) << "state=" << state; } // need to actively flush the writer // so the data get pushed out. @@ -354,7 +357,7 @@ class RPCEndpoint::EventHandler : public support::Stream { flush_writer_(); } state_ = state; - ICHECK(state != kInitHeader) << "cannot switch to init header"; + TVM_FFI_ICHECK(state != kInitHeader) << "cannot switch to init header"; if (state == kRecvPacketNumBytes) { this->RequestBytes(sizeof(uint64_t)); // recycle arena for the next session. @@ -372,7 +375,7 @@ class RPCEndpoint::EventHandler : public support::Stream { this->RequestBytes(len); return; } else { - ICHECK_EQ(init_header_step_, 1); + TVM_FFI_ICHECK_EQ(init_header_step_, 1); this->ReadArray(remote_key_->data(), remote_key_->length()); this->SwitchToState(kRecvPacketNumBytes); } @@ -416,7 +419,7 @@ class RPCEndpoint::EventHandler : public support::Stream { break; } default: - LOG(FATAL) << "Unknown event " << static_cast(code); + TVM_FFI_THROW(InternalError) << "Unknown event " << static_cast(code); } } } @@ -465,13 +468,15 @@ class RPCEndpoint::EventHandler : public support::Stream { // switch to the state before sending exception. this->SwitchToState(kRecvPacketNumBytes); ffi::String msg = args[0].cast(); - if (!support::StartsWith(msg, "RPCSessionTimeoutError: ")) { + if (support::StartsWith(msg, "RPCSessionTimeoutError: ")) { + TVM_FFI_THROW(RPCSessionTimeoutError) << msg; + } else { msg = "RPCError: Error caught from RPC call:\n" + msg; + TVM_FFI_THROW(RPCError) << msg; } - LOG(FATAL) << msg; } - ICHECK(setreturn != nullptr) << "fsetreturn not available"; + TVM_FFI_ICHECK(setreturn != nullptr) << "fsetreturn not available"; setreturn(args); this->SwitchToState(kReturnReceived); @@ -594,10 +599,10 @@ class RPCEndpoint::EventHandler : public support::Stream { ffi::PackedArgs args = RecvPackedSeq(); try { - ICHECK(serving_session_ == nullptr) << "Server has already been initialized"; + TVM_FFI_ICHECK(serving_session_ == nullptr) << "Server has already been initialized"; std::string server_protocol_ver = kRPCProtocolVer; - ICHECK_EQ(client_protocol_ver, server_protocol_ver) + TVM_FFI_ICHECK_EQ(client_protocol_ver, server_protocol_ver) << "Server[" << name_ << "]: Client protocol version mismatch with the server " << " server protocol=" << server_protocol_ver << ", client protocol=" << client_protocol_ver; @@ -614,24 +619,27 @@ class RPCEndpoint::EventHandler : public support::Stream { } auto fconstructor = tvm::ffi::Function::GetGlobal(constructor_name); - ICHECK(fconstructor.has_value()) << " Cannot find session constructor " << constructor_name; + TVM_FFI_ICHECK(fconstructor.has_value()) + << " Cannot find session constructor " << constructor_name; ffi::Any con_ret; try { fconstructor->CallPacked(constructor_args, &con_ret); } catch (const Error& e) { - LOG(FATAL) << "Server[" << name_ << "]:" - << " Error caught from session constructor " << constructor_name << ":\n" - << e.what(); + TVM_FFI_THROW(InternalError) + << "Server[" << name_ << "]:" + << " Error caught from session constructor " << constructor_name << ":\n" + << e.what(); } auto opt_con_ret = con_ret.as(); // Legacy ABI translation - ICHECK(opt_con_ret.has_value()) + TVM_FFI_ICHECK(opt_con_ret.has_value()) << "Server[" << name_ << "]:" << " Constructor " << constructor_name << " need to return an RPCModule"; ffi::Module mod = opt_con_ret.value(); std::string tkey = mod->kind(); - ICHECK_EQ(tkey, "rpc") << "Constructor " << constructor_name << " to return an RPCModule"; + TVM_FFI_ICHECK_EQ(tkey, "rpc") + << "Constructor " << constructor_name << " to return an RPCModule"; serving_session_ = RPCModuleGetSession(mod); this->ReturnVoid(); } catch (const std::exception& e) { @@ -681,9 +689,9 @@ class RPCEndpoint::EventHandler : public support::Stream { private: RPCSession* GetServingSession() const { - ICHECK(serving_session_ != nullptr) + TVM_FFI_ICHECK(serving_session_ != nullptr) << "Need to call InitRemoteSession first before any further actions"; - ICHECK(!serving_session_->IsAsync() || async_server_mode_) + TVM_FFI_ICHECK(!serving_session_->IsAsync() || async_server_mode_) << "Cannot host an async session in a non-Event driven server"; return serving_session_.get(); @@ -691,7 +699,7 @@ class RPCEndpoint::EventHandler : public support::Stream { // Utility functions // Internal read function, update pending_request_bytes_ size_t Read(void* data, size_t size) final { - ICHECK_LE(size, pending_request_bytes_); + TVM_FFI_ICHECK_LE(size, pending_request_bytes_); reader_->Read(data, size); pending_request_bytes_ -= size; return size; @@ -721,8 +729,8 @@ class RPCEndpoint::EventHandler : public support::Stream { RPCCode RPCEndpoint::HandleUntilReturnEvent(bool client_mode, RPCSession::FEncodeReturn setreturn) { RPCCode code = RPCCode::kCallFunc; - CHECK(channel_) << "Expected connection to server " << name_ - << " to be active, but the connection was previously closed"; + TVM_FFI_ICHECK(channel_) << "Expected connection to server " << name_ + << " to be active, but the connection was previously closed"; while (code != RPCCode::kReturn && code != RPCCode::kShutdown && code != RPCCode::kCopyAck) { while (writer_.bytes_available() != 0) { writer_.ReadWithCallback( @@ -737,7 +745,7 @@ RPCCode RPCEndpoint::HandleUntilReturnEvent(bool client_mode, RPCSession::FEncod if (handler_->CanCleanShutdown()) { return RPCCode::kShutdown; } else { - LOG(FATAL) << "Channel closes before we get needed bytes"; + TVM_FFI_THROW(InternalError) << "Channel closes before we get needed bytes"; } } } @@ -776,10 +784,10 @@ void RPCEndpoint::Init() { handler_->SendPackedSeq(args.data(), args.size(), true); code = HandleUntilReturnEvent(true, [rv](ffi::PackedArgs args) { - ICHECK_EQ(args.size(), 1); + TVM_FFI_ICHECK_EQ(args.size(), 1); *rv = args[0]; }); - ICHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); + TVM_FFI_ICHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); }); } @@ -833,7 +841,7 @@ void RPCEndpoint::ServerLoop() { (*f)(); } ffi::Any rv; - ICHECK(HandleUntilReturnEvent(false, [](ffi::PackedArgs) {}) == RPCCode::kShutdown); + TVM_FFI_ICHECK(HandleUntilReturnEvent(false, [](ffi::PackedArgs) {}) == RPCCode::kShutdown); if (const auto f = tvm::ffi::Function::GetGlobal("tvm.rpc.server.shutdown")) { (*f)(); } @@ -852,7 +860,7 @@ int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int even [this](const void* data, size_t size) { return channel_->Send(data, size); }, writer_.bytes_available()); } - ICHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck); + TVM_FFI_ICHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck); // if the code is kShutdown, return 0 to indicate the server should exit if (code == RPCCode::kShutdown) return 0; // if the writer has bytes available, return 2 to indicate the server should send data @@ -880,7 +888,7 @@ void RPCEndpoint::InitRemoteSession(ffi::PackedArgs args) { handler_->SendPackedSeq(args.data(), args.size(), true); code = HandleUntilReturnEvent(true, [](ffi::PackedArgs args) {}); - ICHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); + TVM_FFI_ICHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); } // Get remote function with name @@ -902,7 +910,7 @@ void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h, ffi::PackedArgs args, handler_->SendPackedSeq(args.data(), args.size(), true); code = HandleUntilReturnEvent(true, encode_return); - ICHECK(code == RPCCode::kReturn) << "code=" << RPCCodeToString(code); + TVM_FFI_ICHECK(code == RPCCode::kReturn) << "code=" << RPCCodeToString(code); } void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) { @@ -910,7 +918,7 @@ void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) RPCCode code = RPCCode::kCopyToRemote; uint64_t tensor_total_size_bytes = static_cast(ffi::GetDataSize(*to)); - ICHECK_LE(to->byte_offset + nbytes, tensor_total_size_bytes) + TVM_FFI_ICHECK_LE(to->byte_offset + nbytes, tensor_total_size_bytes) << "CopyToRemote: overflow in tensor size: (byte_offset=" << to->byte_offset << ", nbytes=" << nbytes << ", tensor_total_size=" << tensor_total_size_bytes << ")"; @@ -922,7 +930,7 @@ void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) RPCReference::SendDLTensor(handler_, to); handler_->Write(nbytes); handler_->WriteArray(reinterpret_cast(from_bytes), nbytes); - ICHECK(HandleUntilReturnEvent(true, [](ffi::PackedArgs) {}) == RPCCode::kReturn); + TVM_FFI_ICHECK(HandleUntilReturnEvent(true, [](ffi::PackedArgs) {}) == RPCCode::kReturn); } void RPCEndpoint::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes) { @@ -930,7 +938,7 @@ void RPCEndpoint::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes RPCCode code = RPCCode::kCopyFromRemote; uint64_t tensor_total_size_bytes = static_cast(ffi::GetDataSize(*from)); - ICHECK_LE(from->byte_offset + nbytes, tensor_total_size_bytes) + TVM_FFI_ICHECK_LE(from->byte_offset + nbytes, tensor_total_size_bytes) << "CopyFromRemote: overflow in tensor size: (byte_offset=" << from->byte_offset << ", nbytes=" << nbytes << ", tensor_total_size=" << tensor_total_size_bytes << ")"; @@ -941,7 +949,7 @@ void RPCEndpoint::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes handler_->Write(code); RPCReference::SendDLTensor(handler_, from); handler_->Write(nbytes); - ICHECK(HandleUntilReturnEvent(true, [](ffi::PackedArgs) {}) == RPCCode::kCopyAck); + TVM_FFI_ICHECK(HandleUntilReturnEvent(true, [](ffi::PackedArgs) {}) == RPCCode::kCopyAck); handler_->ReadArray(reinterpret_cast(to_bytes), nbytes); handler_->FinishCopyAck(); @@ -1013,7 +1021,8 @@ void RPCCopyAmongRemote(RPCSession* handler, ffi::PackedArgs args, ffi::Any* rv) if (dev.device_type == kDLCPU) { dev = to->device; } else { - ICHECK(to->device.device_type == kDLCPU || to->device.device_type == from->device.device_type) + TVM_FFI_ICHECK(to->device.device_type == kDLCPU || + to->device.device_type == from->device.device_type) << "Can not copy across different dev types directly"; } handler->GetDeviceAPI(dev)->CopyDataFromTo(from, to, stream); @@ -1086,11 +1095,11 @@ void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { SysCallHandler(RPCCopyAmongRemote); break; default: - LOG(FATAL) << "Unknown event " << static_cast(code); + TVM_FFI_THROW(InternalError) << "Unknown event " << static_cast(code); } if (state_ != kWaitForAsyncCallback) { - ICHECK_EQ(state_, kRecvPacketNumBytes); + TVM_FFI_ICHECK_EQ(state_, kRecvPacketNumBytes); } } @@ -1118,7 +1127,7 @@ class RPCClientSession : public RPCSession, public DeviceAPI { RPCCode code = RPCCode::kCopyToRemote; uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_to, code, nbytes); uint64_t rpc_max_size = GetRPCMaxTransferSize(); - ICHECK_GT(rpc_max_size, overhead) << "CopyToRemote: Invalid block size!"; + TVM_FFI_ICHECK_GT(rpc_max_size, overhead) << "CopyToRemote: Invalid block size!"; const uint64_t block_size = rpc_max_size - overhead; uint64_t block_count = 0; const uint64_t num_blocks = nbytes / block_size; @@ -1144,7 +1153,7 @@ class RPCClientSession : public RPCSession, public DeviceAPI { RPCCode code = RPCCode::kCopyFromRemote; uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_from, code, nbytes); uint64_t rpc_max_size = GetRPCMaxTransferSize(); - ICHECK_GT(rpc_max_size, overhead) << "CopyFromRemote: Invalid block size!"; + TVM_FFI_ICHECK_GT(rpc_max_size, overhead) << "CopyFromRemote: Invalid block size!"; const uint64_t block_size = rpc_max_size - overhead; uint64_t block_count = 0; const uint64_t num_blocks = nbytes / block_size; @@ -1253,7 +1262,7 @@ class RPCClientSession : public RPCSession, public DeviceAPI { // Use args[1] as return value, args[0] is tcode // Look at RPCWrappedFunc in src/runtime/rpc/rpc_module.cc rpc_chunk_max_size_bytes_ = args[1].cast(); - ICHECK_GT(rpc_chunk_max_size_bytes_, 0) + TVM_FFI_ICHECK_GT(rpc_chunk_max_size_bytes_, 0) << "RPC max transfer size is <= 0! (remote value = " << rpc_chunk_max_size_bytes_ << ")"; }); diff --git a/src/runtime/rpc/rpc_event_impl.cc b/src/runtime/rpc/rpc_event_impl.cc index 4eefb2b2b978..e0e747283c0a 100644 --- a/src/runtime/rpc/rpc_event_impl.cc +++ b/src/runtime/rpc/rpc_event_impl.cc @@ -34,8 +34,9 @@ namespace runtime { ffi::Function CreateEventDrivenServer(ffi::Function fsend, std::string name, std::string remote_key) { - static ffi::Function frecv( - [](ffi::PackedArgs args, ffi::Any* rv) { LOG(FATAL) << "Do not allow explicit receive"; }); + static ffi::Function frecv([](ffi::PackedArgs args, ffi::Any* rv) { + TVM_FFI_THROW(InternalError) << "Do not allow explicit receive"; + }); auto ch = std::make_unique(fsend, frecv); std::shared_ptr sess = RPCEndpoint::Create(std::move(ch), name, remote_key); diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc index 2cfeacfcd71f..4a534dded860 100644 --- a/src/runtime/rpc/rpc_local_session.cc +++ b/src/runtime/rpc/rpc_local_session.cc @@ -106,7 +106,7 @@ void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, ffi::PackedArgs a } void LocalSession::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) { - ICHECK_EQ(nbytes, GetDataSize(*to)); + TVM_FFI_ICHECK_EQ(nbytes, GetDataSize(*to)); DLTensor from; from.data = from_bytes; from.device = {kDLCPU, 0}; @@ -123,7 +123,7 @@ void LocalSession::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) } void LocalSession::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes) { - ICHECK_EQ(nbytes, ffi::GetDataSize(*from)); + TVM_FFI_ICHECK_EQ(nbytes, ffi::GetDataSize(*from)); DLTensor to; to.data = to_bytes; to.device = {kDLCPU, 0}; diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index a90c69c63c8b..13f2f0bb7c07 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -53,7 +53,7 @@ namespace runtime { Tensor TensorFromRemoteOpaqueHandle(std::shared_ptr sess, void* handle, DLTensor* template_tensor, Device dev, void* remote_tensor_handle) { - ICHECK_EQ(sess->table_index(), GetRPCSessionIndex(dev)) + TVM_FFI_ICHECK_EQ(sess->table_index(), GetRPCSessionIndex(dev)) << "The Device given does not belong to the given session"; class RemoteSpaceAlloc { public: @@ -162,8 +162,8 @@ class RPCWrappedFunc : public Object { // remove a remote session mask Device RemoveSessMask(Device dev) const { - ICHECK(IsRPCSessionDevice(dev)) << "Can not pass in local device"; - ICHECK_EQ(GetRPCSessionIndex(dev), sess_->table_index()) + TVM_FFI_ICHECK(IsRPCSessionDevice(dev)) << "Can not pass in local device"; + TVM_FFI_ICHECK_EQ(GetRPCSessionIndex(dev), sess_->table_index()) << "Can not pass in device with a different remote session"; return RemoveRPCSessionMask(dev); } @@ -209,8 +209,8 @@ class RPCModuleNode final : public ffi::ModuleObj { int cache_flush_bytes, const std::string& f_preproc_name) { InitRemoteFunc(&remote_get_time_evaluator_, "runtime.RPCTimeEvaluator"); // Remove session mask because we pass dev by parts. - ICHECK_EQ(GetRPCSessionIndex(dev), sess_->table_index()) - << "ValueError: Need to pass the matched remote device to RPCModule.GetTimeEvaluator"; + TVM_FFI_CHECK_EQ(GetRPCSessionIndex(dev), sess_->table_index(), ValueError) + << "Need to pass the matched remote device to RPCModule.GetTimeEvaluator"; dev = RemoveRPCSessionMask(dev); if (module_handle_ != nullptr) { @@ -245,7 +245,7 @@ class RPCModuleNode final : public ffi::ModuleObj { void InitRemoteFunc(FType* func, const std::string& name) { if (*func != nullptr) return; RPCSession::PackedFuncHandle handle = sess_->GetFunction(name); - ICHECK(handle != nullptr) << "Cannot found remote function " << name; + TVM_FFI_ICHECK(handle != nullptr) << "Cannot found remote function " << name; *func = WrapRemoteFunc(handle); } @@ -277,14 +277,14 @@ void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const AnyView& arg) const { if (arg.type_index() == ffi::TypeIndex::kTVMFFIModule) { ffi::Module mod = arg.cast(); std::string tkey = mod->kind(); - ICHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote"; + TVM_FFI_CHECK_EQ(tkey, "rpc", ValueError) << "Cannot pass a non-RPC module to remote"; auto* rmod = static_cast(mod.operator->()); - ICHECK(rmod->sess() == sess_) - << "ValueError: Cannot pass in module into a different remote session"; + TVM_FFI_CHECK(rmod->sess() == sess_, ValueError) + << "Cannot pass in module into a different remote session"; return rmod->module_handle(); } else { - LOG(FATAL) << "ValueError: Cannot pass type " << arg.GetTypeKey() - << " as an argument to the remote"; + TVM_FFI_THROW(ValueError) << "Cannot pass type " << arg.GetTypeKey() + << " as an argument to the remote"; return nullptr; } } @@ -295,19 +295,19 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) *rv = nullptr; return; } else if (type_index == ffi::TypeIndex::kTVMFFIFunction) { - ICHECK_EQ(args.size(), 2); + TVM_FFI_ICHECK_EQ(args.size(), 2); void* handle = args[1].cast(); auto wf = std::make_shared(handle, sess_); *rv = ffi::Function( [wf](ffi::PackedArgs args, ffi::Any* rv) { return wf->operator()(args, rv); }); } else if (type_index == ffi::TypeIndex::kTVMFFIModule) { - ICHECK_EQ(args.size(), 2); + TVM_FFI_ICHECK_EQ(args.size(), 2); void* handle = args[1].cast(); auto n = ffi::make_object(handle, sess_); *rv = ffi::Module(n); } else if (type_index == ffi::TypeIndex::kTVMFFITensor || type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr) { - ICHECK_EQ(args.size(), 3); + TVM_FFI_ICHECK_EQ(args.size(), 3); auto tensor = args[1].cast(); void* nd_handle = args[2].cast(); *rv = TensorFromRemoteOpaqueHandle(sess_, tensor->data, tensor, @@ -317,15 +317,15 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(ffi::PackedArgs args, ffi::Any* rv) type_index == ffi::TypeIndex::kTVMFFIStr || type_index == ffi::TypeIndex::kTVMFFISmallStr || type_index == ffi::TypeIndex::kTVMFFISmallBytes) { - ICHECK_EQ(args.size(), 2); + TVM_FFI_ICHECK_EQ(args.size(), 2); *rv = args[1]; } else if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - ICHECK_EQ(args.size(), 2); + TVM_FFI_ICHECK_EQ(args.size(), 2); void* handle = args[1].cast(); auto n = ffi::make_object(handle, sess_); *rv = ObjectRef(n); } else { - ICHECK_EQ(args.size(), 2); + TVM_FFI_ICHECK_EQ(args.size(), 2); *rv = args[1]; } } @@ -338,7 +338,7 @@ ffi::Module CreateRPCSessionModule(std::shared_ptr sess) { std::shared_ptr RPCModuleGetSession(ffi::Module mod) { std::string tkey = mod->kind(); - ICHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote"; + TVM_FFI_CHECK_EQ(tkey, "rpc", ValueError) << "Cannot pass a non-RPC module to remote"; auto* rmod = static_cast(mod.operator->()); return rmod->sess(); } @@ -416,23 +416,25 @@ TVM_FFI_STATIC_INIT_BLOCK() { ffi::Function f_preproc; if (!f_preproc_name.empty()) { auto pf_preproc = tvm::ffi::Function::GetGlobal(f_preproc_name); - ICHECK(pf_preproc.has_value()) + TVM_FFI_ICHECK(pf_preproc.has_value()) << "Cannot find " << f_preproc_name << " in the global function"; f_preproc = *pf_preproc; } ffi::Optional pf = m->GetFunction(name); - CHECK(pf.has_value()) << "Cannot find " << name << "` in the global registry"; + TVM_FFI_ICHECK(pf.has_value()) + << "Cannot find " << name << "` in the global registry"; return profiling::WrapTimeEvaluator( *pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown, cache_flush_bytes, f_preproc); } } else { auto pf = tvm::ffi::Function::GetGlobal(name); - ICHECK(pf.has_value()) << "Cannot find " << name << " in the global function"; + TVM_FFI_ICHECK(pf.has_value()) + << "Cannot find " << name << " in the global function"; ffi::Function f_preproc; if (!f_preproc_name.empty()) { auto pf_preproc = tvm::ffi::Function::GetGlobal(f_preproc_name); - ICHECK(pf_preproc.has_value()) + TVM_FFI_ICHECK(pf_preproc.has_value()) << "Cannot find " << f_preproc_name << " in the global function"; f_preproc = *pf_preproc; } @@ -464,20 +466,20 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("rpc.LoadRemoteModule", [](ffi::Module sess, std::string name) { std::string tkey = sess->kind(); - ICHECK_EQ(tkey, "rpc"); + TVM_FFI_ICHECK_EQ(tkey, "rpc"); return static_cast(sess.operator->())->LoadModule(name); }) .def("rpc.ImportRemoteModule", [](ffi::Module parent, ffi::Module child) { std::string tkey = parent->kind(); - ICHECK_EQ(tkey, "rpc"); + TVM_FFI_ICHECK_EQ(tkey, "rpc"); static_cast(parent.operator->())->ImportModule(child); }) .def_packed("rpc.SessTableIndex", [](ffi::PackedArgs args, ffi::Any* rv) { ffi::Module m = args[0].cast(); std::string tkey = m->kind(); - ICHECK_EQ(tkey, "rpc"); + TVM_FFI_ICHECK_EQ(tkey, "rpc"); *rv = static_cast(m.operator->())->sess()->table_index(); }) .def("tvm.rpc.TensorFromRemoteOpaqueHandle", diff --git a/src/runtime/rpc/rpc_pipe_impl.cc b/src/runtime/rpc/rpc_pipe_impl.cc index fd02e747bf5d..d3986fde7eb3 100644 --- a/src/runtime/rpc/rpc_pipe_impl.cc +++ b/src/runtime/rpc/rpc_pipe_impl.cc @@ -51,7 +51,7 @@ class PipeChannel final : public RPCChannel { size_t Send(const void* data, size_t size) final { ssize_t n = write(writefd_, data, size); if (n == -1) { - LOG(FATAL) << "Pipe write error"; + TVM_FFI_THROW(InternalError) << "Pipe write error"; } return static_cast(n); } @@ -59,7 +59,7 @@ class PipeChannel final : public RPCChannel { size_t Recv(void* data, size_t size) final { ssize_t n = read(readfd_, data, size); if (n == -1) { - LOG(FATAL) << "Pipe read error"; + TVM_FFI_THROW(InternalError) << "Pipe read error"; } return static_cast(n); } @@ -79,8 +79,8 @@ class PipeChannel final : public RPCChannel { ffi::Module CreatePipeClient(std::vector cmd) { int parent2child[2]; int child2parent[2]; - ICHECK_EQ(pipe(parent2child), 0); - ICHECK_EQ(pipe(child2parent), 0); + TVM_FFI_ICHECK_EQ(pipe(parent2child), 0); + TVM_FFI_ICHECK_EQ(pipe(child2parent), 0); int parent_read = child2parent[0]; int parent_write = parent2child[1]; diff --git a/src/runtime/rpc/rpc_server_env.cc b/src/runtime/rpc/rpc_server_env.cc index c8e7a4ee81c9..a51d98b17f93 100644 --- a/src/runtime/rpc/rpc_server_env.cc +++ b/src/runtime/rpc/rpc_server_env.cc @@ -32,7 +32,7 @@ namespace runtime { std::string RPCGetPath(const std::string& name) { // do live lookup everytime as workpath can change. const auto f = tvm::ffi::Function::GetGlobal("tvm.rpc.server.workpath"); - ICHECK(f.has_value()) << "require tvm.rpc.server.workpath"; + TVM_FFI_ICHECK(f.has_value()) << "require tvm.rpc.server.workpath"; return (*f)(name).cast(); } diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index 1fee1424ea22..349568062eb1 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -96,7 +96,7 @@ class RPCSessTable { } // Get session from table std::shared_ptr Get(int index) { - ICHECK(index >= 0 && index < kMaxRPCSession); + TVM_FFI_ICHECK(index >= 0 && index < kMaxRPCSession); return tbl_[index].lock(); } // Insert session into table. @@ -108,7 +108,7 @@ class RPCSessTable { return i; } } - LOG(FATAL) << "maximum number of RPC session reached"; + TVM_FFI_THROW(InternalError) << "maximum number of RPC session reached"; } private: @@ -124,7 +124,7 @@ std::shared_ptr RPCSession::Get(int table_index) { } void RPCSession::InsertToSessionTable(std::shared_ptr sess) { - ICHECK_EQ(sess->table_index_, 0); + TVM_FFI_ICHECK_EQ(sess->table_index_, 0); sess->table_index_ = RPCSessTable::Global()->Insert(sess); } diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index eaff1f539f0e..aa75145ce2fc 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -72,32 +72,34 @@ std::shared_ptr RPCConnect(std::string url, int port, std::string k support::TCPSocket sock; support::SockAddr addr(url.c_str(), port); sock.Create(addr.ss_family()); - ICHECK(sock.Connect(addr)) << "Connect to " << addr.AsString() << " failed"; + TVM_FFI_ICHECK(sock.Connect(addr)) << "Connect to " << addr.AsString() << " failed"; // hand shake std::ostringstream os; int code = kRPCMagic; int keylen = static_cast(key.length()); - ICHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code)); - ICHECK_EQ(sock.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); + TVM_FFI_ICHECK_EQ(sock.SendAll(&code, sizeof(code)), sizeof(code)); + TVM_FFI_ICHECK_EQ(sock.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); if (keylen != 0) { - ICHECK_EQ(sock.SendAll(key.c_str(), keylen), keylen); + TVM_FFI_ICHECK_EQ(sock.SendAll(key.c_str(), keylen), keylen); } - ICHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code)); + TVM_FFI_ICHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code)); if (code == kRPCMagic + 2) { sock.Close(); - LOG(FATAL) << "URL " << url << ":" << port << " cannot find server that matches key=" << key; + TVM_FFI_THROW(InternalError) << "URL " << url << ":" << port + << " cannot find server that matches key=" << key; } else if (code == kRPCMagic + 1) { sock.Close(); - LOG(FATAL) << "URL " << url << ":" << port << " server already have key=" << key; + TVM_FFI_THROW(InternalError) << "URL " << url << ":" << port + << " server already have key=" << key; } else if (code != kRPCMagic) { sock.Close(); - LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server"; + TVM_FFI_THROW(InternalError) << "URL " << url << ":" << port << " is not TVM RPC server"; } - ICHECK_EQ(sock.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen)); + TVM_FFI_ICHECK_EQ(sock.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen)); std::string remote_key; if (keylen != 0) { remote_key.resize(keylen); - ICHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen); + TVM_FFI_ICHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen); } std::unique_ptr channel = std::make_unique(sock); diff --git a/src/runtime/static_library.cc b/src/runtime/static_library.cc index 621fcc506bac..f53e12a15dc1 100644 --- a/src/runtime/static_library.cc +++ b/src/runtime/static_library.cc @@ -74,12 +74,12 @@ class StaticLibraryNode final : public ffi::ModuleObj { auto n = ffi::make_object(); // load data std::string data; - ICHECK(stream.Read(&data)) << "Loading data failed"; + TVM_FFI_ICHECK(stream.Read(&data)) << "Loading data failed"; n->data_ = std::move(data); // load func names std::vector func_names; - ICHECK(stream.Read(&func_names)) << "Loading func names failed"; + TVM_FFI_ICHECK(stream.Read(&func_names)) << "Loading func names failed"; for (auto func_name : func_names) n->func_names_.push_back(ffi::String(func_name)); return ffi::Module(n); diff --git a/src/runtime/tensor.cc b/src/runtime/tensor.cc index 4ef744452c3c..d4fe1772b978 100644 --- a/src/runtime/tensor.cc +++ b/src/runtime/tensor.cc @@ -34,9 +34,9 @@ namespace tvm { namespace runtime { inline void VerifyDataType(DLDataType dtype) { - ICHECK_GE(dtype.lanes, 1); + TVM_FFI_ICHECK_GE(dtype.lanes, 1); if (dtype.code == kDLFloat) { - ICHECK_EQ(dtype.bits % 8, 0); + TVM_FFI_ICHECK_EQ(dtype.bits % 8, 0); } else { // allow uint1 as a special flag for bool. if (dtype.bits == 1 && dtype.code == kDLUInt) return; @@ -54,15 +54,16 @@ inline void VerifyDataType(DLDataType dtype) { else if (dtype.bits == 4 && dtype.code == DataType::kFloat4_e2m1fn) return; else - ICHECK_EQ(dtype.bits % 8, 0); + TVM_FFI_ICHECK_EQ(dtype.bits % 8, 0); } - ICHECK_EQ(dtype.bits & (dtype.bits - 1), 0); + TVM_FFI_ICHECK_EQ(dtype.bits & (dtype.bits - 1), 0); } void TensorCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { size_t arr_size = GetDataSize(*handle); - ICHECK_EQ(arr_size, nbytes) << "TensorCopyFromBytes: size mismatch"; - ICHECK(IsContiguous(*handle)) << "TensorCopyFromBytes only support contiguous array for now"; + TVM_FFI_ICHECK_EQ(arr_size, nbytes) << "TensorCopyFromBytes: size mismatch"; + TVM_FFI_ICHECK(IsContiguous(*handle)) + << "TensorCopyFromBytes only support contiguous array for now"; DLTensor from; from.data = const_cast(data); @@ -80,8 +81,9 @@ void TensorCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { void Tensor::CopyToBytes(const DLTensor* handle, void* data, size_t nbytes, TVMStreamHandle stream) { size_t arr_size = GetDataSize(*handle); - ICHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch"; - ICHECK(ffi::IsContiguous(*handle)) << "ArrayCopyToBytes only support contiguous array for now"; + TVM_FFI_ICHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch"; + TVM_FFI_ICHECK(ffi::IsContiguous(*handle)) + << "ArrayCopyToBytes only support contiguous array for now"; DLTensor to; to.data = const_cast(data); @@ -100,8 +102,9 @@ void Tensor::CopyToBytes(const DLTensor* handle, void* data, size_t nbytes, void Tensor::CopyFromBytes(const DLTensor* handle, void* data, size_t nbytes, TVMStreamHandle stream) { size_t arr_size = GetDataSize(*handle); - ICHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch"; - ICHECK(ffi::IsContiguous(*handle)) << "ArrayCopyToBytes only support contiguous array for now"; + TVM_FFI_ICHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch"; + TVM_FFI_ICHECK(ffi::IsContiguous(*handle)) + << "ArrayCopyToBytes only support contiguous array for now"; DLTensor from; from.data = const_cast(data); @@ -133,10 +136,10 @@ Tensor Tensor::Empty(ffi::Shape shape, DLDataType dtype, Device dev, } Tensor Tensor::CreateView(ffi::Shape shape, DLDataType dtype, uint64_t relative_byte_offset) const { - ICHECK(data_ != nullptr); + TVM_FFI_ICHECK(data_ != nullptr); const DLTensor& orig = *get_mutable(); - CHECK(IsContiguous()) << [&orig]() { + TVM_FFI_ICHECK(IsContiguous()) << [&orig]() { std::stringstream ss; ss << "Can only create view for compact tensor, but found strides "; @@ -159,8 +162,7 @@ Tensor Tensor::CreateView(ffi::Shape shape, DLDataType dtype, uint64_t relative_ const auto& curr_dl_tensor = *get_mutable(); size_t curr_size = GetDataSize(curr_dl_tensor); size_t view_size = ffi::GetDataSize(shape.Product(), dtype); - CHECK_LE(relative_byte_offset + view_size, curr_size) - << "ValueError: " + TVM_FFI_CHECK_LE(relative_byte_offset + view_size, curr_size, ValueError) << "View with shape " << shape << " and datatype " << dtype << " would have a size of " << view_size << " bytes. " << "This would occupy bytes " << relative_byte_offset << " <= i_byte < " @@ -190,19 +192,19 @@ Tensor Tensor::CreateView(ffi::Shape shape, DLDataType dtype, uint64_t relative_ } void Tensor::CopyToBytes(void* data, size_t nbytes) const { - ICHECK(data != nullptr); - ICHECK(data_ != nullptr); + TVM_FFI_ICHECK(data != nullptr); + TVM_FFI_ICHECK(data_ != nullptr); Tensor::CopyToBytes(get_mutable(), data, nbytes); } void Tensor::CopyFromBytes(const void* data, size_t nbytes) { - ICHECK(data != nullptr); - ICHECK(data_ != nullptr); + TVM_FFI_ICHECK(data != nullptr); + TVM_FFI_ICHECK(data_ != nullptr); TensorCopyFromBytes(get_mutable(), data, nbytes); } Tensor Tensor::CopyTo(const Device& dev, ffi::Optional mem_scope) const { - ICHECK(data_ != nullptr); + TVM_FFI_ICHECK(data_ != nullptr); const DLTensor* dptr = operator->(); Tensor ret = Empty(ffi::Shape(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, dev, mem_scope); @@ -215,12 +217,13 @@ Tensor Tensor::CopyTo(const Device& dev, ffi::Optional mem_scope) c void Tensor::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle stream) { size_t from_size = GetDataSize(*from); size_t to_size = GetDataSize(*to); - ICHECK_EQ(from_size, to_size) << "TVMTensorCopyFromTo: The size in bytes must exactly match."; + TVM_FFI_ICHECK_EQ(from_size, to_size) + << "TVMTensorCopyFromTo: The size in bytes must exactly match."; - ICHECK(from->device.device_type == to->device.device_type || from->device.device_type == kDLCPU || - to->device.device_type == kDLCPU || from->device.device_type == kDLCUDAHost || - to->device.device_type == kDLCUDAHost || from->device.device_type == kDLROCMHost || - to->device.device_type == kDLROCMHost) + TVM_FFI_ICHECK(from->device.device_type == to->device.device_type || + from->device.device_type == kDLCPU || to->device.device_type == kDLCPU || + from->device.device_type == kDLCUDAHost || to->device.device_type == kDLCUDAHost || + from->device.device_type == kDLROCMHost || to->device.device_type == kDLROCMHost) << "Can not copy across different device types directly. From device type: " << from->device.device_type << " to device type: " << to->device.device_type; diff --git a/src/runtime/texture.h b/src/runtime/texture.h index e2b6d603ed50..818885ff2173 100644 --- a/src/runtime/texture.h +++ b/src/runtime/texture.h @@ -63,7 +63,8 @@ inline size_t DefaultTextureLayoutSeparator(size_t shape_rank, } else if (convention == "global.texture-nhwc") { separator = 2; } else { - LOG(FATAL) << "Encountered unknown texture lowering convention: " << convention; + TVM_FFI_THROW(InternalError) << "Encountered unknown texture lowering convention: " + << convention; } return separator; } @@ -76,7 +77,7 @@ inline size_t DefaultTextureLayoutSeparator(size_t shape_rank, */ template Texture2DShape ApplyTexture2DFlattening(const S& shape, size_t rank, size_t axis) { - ICHECK(axis < rank) + TVM_FFI_ICHECK(axis < rank) << "Number of axes to flatten into rows must be less than shape rank for 2d flattening"; Texture2DShape texture{1, 1, 1, shape[rank - 1]}; for (size_t i = 0; i < rank - 1; i++) { @@ -129,7 +130,7 @@ inline DataType GetChannelType(size_t channel_size) { else if (channel_size == 64) return DataType::Float(16, 4); - LOG(FATAL) << "Unsupported Channel Size: " << channel_size; + TVM_FFI_THROW(InternalError) << "Unsupported Channel Size: " << channel_size; } } // namespace runtime diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index ac605ef86da5..0a46336818b8 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -195,7 +195,7 @@ class SpscTaskQueue { } const uint32_t head = head_.load(std::memory_order_relaxed); // sanity check if the queue is empty - ICHECK(tail_.load(std::memory_order_acquire) != head); + TVM_FFI_ICHECK(tail_.load(std::memory_order_acquire) != head); *output = buffer_[head]; head_.store((head + 1) % kRingSize, std::memory_order_release); return true; @@ -290,13 +290,13 @@ class ThreadPool { int Launch(FTVMParallelLambda flambda, void* cdata, int num_task, int need_sync) { ParallelLauncher* launcher = ParallelLauncher::ThreadLocal(); - ICHECK(!launcher->is_worker) + TVM_FFI_ICHECK(!launcher->is_worker) << "Cannot launch parallel job inside worker, consider fuse then parallel"; if (num_task == 0) { num_task = num_workers_used_; } if (need_sync != 0) { - ICHECK_LE(num_task, num_workers_used_) + TVM_FFI_ICHECK_LE(num_task, num_workers_used_) << "Request parallel sync task larger than number of threads used " << " workers=" << num_workers_used_ << " request=" << num_task; } @@ -361,7 +361,7 @@ class ThreadPool { // TODO(tulloch): should we make this configurable via standard APIs? static size_t spin_count = GetSpinCount(); while (queue->Pop(&task, spin_count)) { - ICHECK(task.launcher != nullptr); + TVM_FFI_ICHECK(task.launcher != nullptr); TVMParallelGroupEnv* penv = &(task.launcher->env); void* cdata = task.launcher->cdata; if ((*task.launcher->flambda)(task.task_id, penv, cdata) == 0) { @@ -396,7 +396,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (args.size() >= 3) { auto cpu_array = args[2].cast>(); for (auto cpu : cpu_array) { - ICHECK(IsNumber(cpu)) + TVM_FFI_ICHECK(IsNumber(cpu)) << "The CPU core information '" << cpu << "' is not a number."; cpus.push_back(std::stoi(cpu)); } diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 9d6e71962271..313e4cfe484c 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -86,7 +86,7 @@ inline StorageRank DefaultStorageRank(int thread_scope_rank) { case 1: return StorageRank::kLocal; default: { - LOG(FATAL) << "unknown rank"; + TVM_FFI_THROW(InternalError) << "unknown rank"; } } } @@ -130,7 +130,7 @@ struct StorageScope { case StorageRank::kMetalSimdGroup: return "metal.simdgroup" + tag; default: - LOG(FATAL) << "unknown storage scope"; + TVM_FFI_THROW(InternalError) << "unknown storage scope"; return ""; } } @@ -183,7 +183,7 @@ struct StorageScope { r.rank = StorageRank::kMetalSimdGroup; r.tag = s.substr(15, std::string::npos); } else { - LOG(FATAL) << "unknown storage scope " << s; + TVM_FFI_THROW(InternalError) << "unknown storage scope " << s; } return r; } @@ -213,7 +213,7 @@ struct ThreadScope { r.rank = 1; r.dim_index = static_cast(s[10] - 'x'); } else { - LOG(FATAL) << "Unknown threadscope " << s; + TVM_FFI_THROW(InternalError) << "Unknown threadscope " << s; } return r; } @@ -245,7 +245,7 @@ class LaunchParamConfig { for (size_t i = 0; i < launch_param_tags.size(); ++i) { std::string tag(launch_param_tags[i]); if (tag == launch_param::kUseDynamicSharedMemoryTag) { - ICHECK_EQ(i, launch_param_tags.size() - 1) + TVM_FFI_ICHECK_EQ(i, launch_param_tags.size() - 1) << "kUseDynamicSharedMemoryTag should be the last tag in launch_param_tags."; use_dyn_shared_memory_ = true; } else if (tag == launch_param::kUseProgramaticDependentLaunch) { diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc index c4f6b3e17777..549e6fb9dd4d 100644 --- a/src/runtime/threading_backend.cc +++ b/src/runtime/threading_backend.cc @@ -66,17 +66,17 @@ class QuRTThread { qurt_thread_attr_t attr; char name[32]; int ret = posix_memalign(&stack_, HEXAGON_STACK_ALIGNMENT, HEXAGON_STACK_SIZE); - CHECK_EQ(ret, 0); + TVM_FFI_ICHECK_EQ(ret, 0); // When a std::function<> is cast to bool, // it indicates whether it stores a callable target - CHECK_EQ((bool)worker_callback_, true); + TVM_FFI_ICHECK_EQ((bool)worker_callback_, true); qurt_thread_attr_init(&attr); qurt_thread_attr_set_stack_size(&attr, HEXAGON_STACK_SIZE); qurt_thread_attr_set_stack_addr(&attr, stack_); snprintf(name, sizeof(name), "worker %d", id++); qurt_thread_attr_set_name(&attr, name); ret = qurt_thread_create(&thread_, &attr, (void (*)(void*))RunFunction, this); - CHECK_EQ(ret, QURT_EOK); + TVM_FFI_ICHECK_EQ(ret, QURT_EOK); } QuRTThread(QuRTThread&& other) : thread_(other.thread_), @@ -147,7 +147,7 @@ class ThreadGroup::Impl { public: Impl(int num_workers, std::function worker_callback, bool exclude_worker0) : num_workers_(num_workers) { - ICHECK_GE(num_workers, 1) << "Requested a non-positive number of worker threads."; + TVM_FFI_ICHECK_GE(num_workers, 1) << "Requested a non-positive number of worker threads."; for (int i = exclude_worker0; i < num_workers_; ++i) { threads_.emplace_back([worker_callback, i] { worker_callback(i); }); } @@ -225,7 +225,7 @@ class ThreadGroup::Impl { break; } } else { - ICHECK_GE(sorted_order_.size(), num_workers_); + TVM_FFI_ICHECK_GE(sorted_order_.size(), num_workers_); switch (mode) { case kSpecifyThreadShareAllCore: for (unsigned i = 0; i < threads_.size(); ++i) { diff --git a/src/runtime/vm/attn_backend.cc b/src/runtime/vm/attn_backend.cc index 13e151ecd202..8bb57103d305 100644 --- a/src/runtime/vm/attn_backend.cc +++ b/src/runtime/vm/attn_backend.cc @@ -32,18 +32,18 @@ std::unique_ptr ConvertPagedPrefillFunc(ffi::Array a } ffi::String backend_name = args[0].cast(); if (backend_name == "tir") { - CHECK_EQ(args.size(), 2); + TVM_FFI_ICHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); } if (backend_name == "flashinfer") { - CHECK_EQ(args.size(), 3); + TVM_FFI_ICHECK_EQ(args.size(), 3); ffi::Function attn_func = args[1].cast(); ffi::Function plan_func = args[2].cast(); return std::make_unique(std::move(attn_func), std::move(plan_func), attn_kind); } - LOG(FATAL) << "Cannot reach here"; + TVM_FFI_THROW(InternalError) << "Cannot reach here"; throw; } @@ -54,12 +54,12 @@ std::unique_ptr ConvertRaggedPrefillFunc(ffi::Array } ffi::String backend_name = args[0].cast(); if (backend_name == "tir") { - CHECK_EQ(args.size(), 2); + TVM_FFI_ICHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); } if (backend_name == "flashinfer") { - CHECK(args.size() == 3 || args.size() == 5); + TVM_FFI_ICHECK(args.size() == 3 || args.size() == 5); ffi::Function attn_func = args[1].cast(); ffi::Function plan_func = args[2].cast(); int64_t qk_head_dim_override = -1; @@ -72,7 +72,7 @@ std::unique_ptr ConvertRaggedPrefillFunc(ffi::Array attn_kind, qk_head_dim_override, v_head_dim_override); } - LOG(FATAL) << "Cannot reach here"; + TVM_FFI_THROW(InternalError) << "Cannot reach here"; throw; } @@ -83,18 +83,18 @@ std::unique_ptr ConvertPagedDecodeFunc(ffi::Array arg } ffi::String backend_name = args[0].cast(); if (backend_name == "tir") { - CHECK_EQ(args.size(), 2); + TVM_FFI_ICHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); } if (backend_name == "flashinfer") { - CHECK_EQ(args.size(), 3); + TVM_FFI_ICHECK_EQ(args.size(), 3); ffi::Function attn_func = args[1].cast(); ffi::Function plan_func = args[2].cast(); return std::make_unique(std::move(attn_func), std::move(plan_func), attn_kind); } - LOG(FATAL) << "Cannot reach here"; + TVM_FFI_THROW(InternalError) << "Cannot reach here"; throw; } @@ -105,11 +105,11 @@ std::unique_ptr ConvertPagedPrefillTreeMaskFunc(ffi::A } ffi::String backend_name = args[0].cast(); if (backend_name == "tir") { - CHECK_EQ(args.size(), 2); + TVM_FFI_ICHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); } - LOG(FATAL) << "Cannot reach here"; + TVM_FFI_THROW(InternalError) << "Cannot reach here"; throw; } @@ -120,11 +120,11 @@ std::unique_ptr ConvertRaggedPrefillTreeMaskFunc( } ffi::String backend_name = args[0].cast(); if (backend_name == "tir") { - CHECK_EQ(args.size(), 2); + TVM_FFI_ICHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); } - LOG(FATAL) << "Cannot reach here"; + TVM_FFI_THROW(InternalError) << "Cannot reach here"; throw; } diff --git a/src/runtime/vm/attn_backend.h b/src/runtime/vm/attn_backend.h index 31f1ce9f4ad2..7f87f4f36d61 100644 --- a/src/runtime/vm/attn_backend.h +++ b/src/runtime/vm/attn_backend.h @@ -93,13 +93,13 @@ class PagedPrefillFunc : public AttnBackendFunc { Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { - LOG(FATAL) << "MHA computation is not supported by the current backend"; + TVM_FFI_THROW(InternalError) << "MHA computation is not supported by the current backend"; } virtual void MLA(int depth, Tensor q, Tensor qo_indptr, Tensor pages, Tensor page_indptr, Tensor page_indices, Tensor length_info, bool causal, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { - LOG(FATAL) << "MLA computation is not supported by the current backend"; + TVM_FFI_THROW(InternalError) << "MLA computation is not supported by the current backend"; } virtual void BeginForward(int depth, Tensor float_workspace_buffer, Tensor int_workspace_buffer, @@ -158,11 +158,11 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc { double rope_rcp_scale = 1 / rotary_scale; double rope_rcp_theta = 1 / rotary_theta; - ICHECK_EQ(pages.ndim(), 5); + TVM_FFI_ICHECK_EQ(pages.ndim(), 5); int H = pages->shape[2]; int N = pages->shape[3]; int D = pages->shape[4]; - CHECK(pages.IsContiguous()); + TVM_FFI_ICHECK(pages.IsContiguous()); std::vector pages_k_v_shape = {pages->shape[0], H, N, D}; std::vector pages_k_v_strides = {2 * H * N * D, N * D, D, 1}; Tensor pages_k = @@ -188,15 +188,15 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc { Device device = q->device; TVMStreamHandle original_stream = DeviceAPI::Get(device)->GetCurrentStream(device); DeviceAPI::Get(device)->SetStream(device, compute_stream); - ICHECK_NE(qk_head_dim_, -1); - ICHECK_NE(v_head_dim_, -1); + TVM_FFI_ICHECK_NE(qk_head_dim_, -1); + TVM_FFI_ICHECK_NE(v_head_dim_, -1); int64_t H = q->shape[1]; int64_t page_size = pages->shape[1]; int64_t rope_head_dim = qk_head_dim_ - v_head_dim_; int64_t nope_head_dim = q->shape[2] - rope_head_dim; // Split q into q_nope and q_pe - CHECK(q.IsContiguous()); + TVM_FFI_ICHECK(q.IsContiguous()); std::vector q_nope_shape = {q->shape[0], H, nope_head_dim}; std::vector q_pe_shape = {q->shape[0], H, rope_head_dim}; std::vector q_strides = {H * q->shape[2], q->shape[2], 1}; @@ -206,7 +206,7 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc { q->device, q_strides.data(), q->byte_offset + nope_head_dim * q.DataType().bytes()); // Split pages into kv_nope and kv_pe - CHECK(pages.IsContiguous()); + TVM_FFI_ICHECK(pages.IsContiguous()); std::vector kv_nope_shape = {pages->shape[0], page_size, nope_head_dim}; std::vector kv_pe_shape = {pages->shape[0], page_size, rope_head_dim}; std::vector kv_strides = {page_size * pages->shape[2], pages->shape[2], 1}; @@ -289,7 +289,7 @@ class RaggedPrefillFunc : public AttnBackendFunc { Tensor q_rope_position, Tensor k_rope_pos_offset, bool causal, RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { - LOG(FATAL) << "MHA computation is not supported by the current backend"; + TVM_FFI_THROW(InternalError) << "MHA computation is not supported by the current backend"; } virtual void BeginForward(Tensor float_workspace_buffer, Tensor int_workspace_buffer, @@ -403,13 +403,13 @@ class PagedDecodeFunc : public AttnBackendFunc { Tensor length_info, Tensor k_rope_pos_offset, Tensor q_rope_position, RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { - LOG(FATAL) << "MHA computation is not supported by the current backend"; + TVM_FFI_THROW(InternalError) << "MHA computation is not supported by the current backend"; } virtual void MLA(int depth, Tensor q, Tensor pages, Tensor page_indptr, Tensor page_indices, Tensor length_info, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { - LOG(FATAL) << "MLA computation is not supported by the current backend"; + TVM_FFI_THROW(InternalError) << "MLA computation is not supported by the current backend"; } virtual void BeginForward(int depth, Tensor float_workspace_buffer, Tensor int_workspace_buffer, @@ -465,11 +465,11 @@ class FlashInferPagedDecodeFunc : public PagedDecodeFunc { double rope_rcp_scale = 1 / rotary_scale; double rope_rcp_theta = 1 / rotary_theta; - ICHECK_EQ(pages.ndim(), 5); + TVM_FFI_ICHECK_EQ(pages.ndim(), 5); int H = pages->shape[2]; int N = pages->shape[3]; int D = pages->shape[4]; - CHECK(pages.IsContiguous()); + TVM_FFI_ICHECK(pages.IsContiguous()); std::vector pages_k_v_shape = {pages->shape[0], H, N, D}; std::vector pages_k_v_strides = {2 * H * N * D, N * D, D, 1}; Tensor pages_k = @@ -531,14 +531,14 @@ class PagedPrefillTreeMaskFunc : public AttnBackendFunc { Tensor q_rope_position, Tensor tree_attn_mn_indptr, Tensor tree_attn_mask, RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { - LOG(FATAL) << "MHA computation is not supported by the current backend"; + TVM_FFI_THROW(InternalError) << "MHA computation is not supported by the current backend"; } virtual void MLA(Tensor q, Tensor qo_indptr, Tensor pages, Tensor page_indptr, Tensor page_indices, Tensor length_info, Tensor tree_attn_mn_indptr, Tensor tree_attn_mask, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { - LOG(FATAL) << "MLA computation is not supported by the current backend"; + TVM_FFI_THROW(InternalError) << "MLA computation is not supported by the current backend"; } virtual void BeginForward(Tensor temp_float_attn_workspace, Tensor temp_int_attn_workspace, @@ -579,13 +579,13 @@ class RaggedPrefillTreeMaskFunc : public AttnBackendFunc { Tensor q_rope_position, Tensor tree_attn_mn_indptr, Tensor tree_attn_mask, RoPEMode rope_mode, double rotary_scale, double rotary_theta, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { - LOG(FATAL) << "MHA computation is not supported by the current backend"; + TVM_FFI_THROW(InternalError) << "MHA computation is not supported by the current backend"; } virtual void MLA(Tensor q, Tensor compressed_kv, Tensor k_pe, Tensor qo_indptr, Tensor kv_indptr, Tensor tree_attn_mn_indptr, Tensor tree_attn_mask, double sm_scale, Tensor attn_output, Tensor attn_lse, TVMStreamHandle compute_stream) { - LOG(FATAL) << "MLA computation is not supported by the current backend"; + TVM_FFI_THROW(InternalError) << "MLA computation is not supported by the current backend"; } virtual void BeginForward(Tensor temp_float_attn_workspace, Tensor temp_int_attn_workspace, diff --git a/src/runtime/vm/attn_utils.h b/src/runtime/vm/attn_utils.h index 1c695a10e25d..afb962e4fc6f 100644 --- a/src/runtime/vm/attn_utils.h +++ b/src/runtime/vm/attn_utils.h @@ -77,7 +77,7 @@ inline ffi::Shape GetKVCacheShape(AttnKind attn_kind, int64_t num_total_pages, i } else if (attn_kind == AttnKind::kLinearAttn) { return {num_sequence, num_kv_heads, qk_head_dim, v_head_dim}; } - ICHECK(false); + TVM_FFI_ICHECK(false); return ffi::Shape(); } @@ -354,12 +354,12 @@ class HostMemoryVector { explicit HostMemoryVector(int64_t reserved_size, DLDataType dtype, Device device) : reserved_size_(reserved_size) { - ICHECK(DataType(dtype) == DataType::Int(32)); + TVM_FFI_ICHECK(DataType(dtype) == DataType::Int(32)); data_ = Tensor::Empty({reserved_size}, dtype, device); } void push_back(int32_t value) { - ICHECK_LE(current_size_, reserved_size_); + TVM_FFI_ICHECK_LE(current_size_, reserved_size_); if (current_size_ == reserved_size_) { reserved_size_ *= 2; Tensor new_data = Tensor::Empty({reserved_size_}, data_->dtype, data_->device); @@ -370,13 +370,13 @@ class HostMemoryVector { } const int32_t& operator[](int64_t idx) const { - ICHECK_GE(idx, 0) << "Index " << idx << " is negative."; - ICHECK_LT(idx, current_size_) << "Index " << idx << " out of bounds " << current_size_; + TVM_FFI_ICHECK_GE(idx, 0) << "Index " << idx << " is negative."; + TVM_FFI_ICHECK_LT(idx, current_size_) << "Index " << idx << " out of bounds " << current_size_; return static_cast(data_->data)[idx]; } int32_t back() const { - ICHECK_GT(current_size_, 0) << "Vector is empty"; + TVM_FFI_ICHECK_GT(current_size_, 0) << "Vector is empty"; return static_cast(data_->data)[current_size_ - 1]; } @@ -429,7 +429,7 @@ class PagedKVCacheAuxDataManager { device_(device), preferred_host_device_(preferred_host_device), copy_stream_(copy_stream) { - ICHECK(DataType(dtype_aux) == DataType::Int(32)); + TVM_FFI_ICHECK(DataType(dtype_aux) == DataType::Int(32)); } virtual ~PagedKVCacheAuxDataManager() = default; @@ -661,7 +661,7 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { HostMemoryVector* sliding_window_offset, HostMemoryVector* sink_size, int depth) final { int n_elem = last_page_len->size(); - ICHECK_GT(n_elem, 0); + TVM_FFI_ICHECK_GT(n_elem, 0); Tensor view = length_info_on_depths_device_[depth].CreateView({3, n_elem}, dtype_aux_); ffi::Shape copy_shape{n_elem}; CopyVecDataToArray(view, last_page_len->data(), copy_shape); @@ -687,7 +687,7 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { Tensor CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data, HostMemoryVector* dst_data) final { int n_elem = src_data->size(); - ICHECK_GT(n_elem, 0); + TVM_FFI_ICHECK_GT(n_elem, 0); Tensor view = commit_copy_src_dst_pos_in_page_table_device_.CreateView({2, n_elem}, dtype_aux_); ffi::Shape copy_shape{n_elem}; CopyVecDataToArray(view, src_data->data(), copy_shape); @@ -717,7 +717,7 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { void* nptr = workspace->GetNativePtr(array); uint64_t copy_size; if (shape.defined()) { - ICHECK_EQ(shape.value().size(), 1); + TVM_FFI_ICHECK_EQ(shape.value().size(), 1); copy_size = shape.value()->data[0] * sizeof(int32_t); } else { copy_size = DeviceAPI::Get(array->device)->GetDataSize(*array.operator->()); @@ -728,7 +728,7 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { #endif if (shape.defined()) { - ICHECK_EQ(shape.value().size(), 1); + TVM_FFI_ICHECK_EQ(shape.value().size(), 1); copy_dst.ndim = 1; copy_dst.shape = const_cast(shape.value()->data); } diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc index 06eaf35ac908..35cc261e4d43 100644 --- a/src/runtime/vm/builtin.cc +++ b/src/runtime/vm/builtin.cc @@ -60,7 +60,7 @@ Tensor AllocShapeHeap(void* ctx_ptr, int64_t size) { if (vm->devices[0].device_type == kDLHexagon) { host_device_index = 0; } else { - ICHECK_EQ(vm->devices[host_device_index].device_type, kDLCPU); + TVM_FFI_ICHECK_EQ(vm->devices[host_device_index].device_type, kDLCPU); } auto* alloc = vm->allocators[host_device_index]; return alloc->Empty({size}, DLDataType{kDLInt, 64, 1}, vm->devices[host_device_index]); @@ -95,17 +95,18 @@ void MatchPrimValue(int64_t input_value, DLTensor* heap, int code_value, int64_t MatchShapeCode code = static_cast(code_value); if (code == MatchShapeCode::kAssertEqualToImm) { - CHECK_EQ(input_value, reg) << "RuntimeError: " << err_ctx.value_or("") << " match_cast error, " - << " PrimValue mismatch to specified constant."; + TVM_FFI_CHECK_EQ(input_value, reg, RuntimeError) + << err_ctx.value_or("") << " match_cast error, " + << " PrimValue mismatch to specified constant."; } else if (code == MatchShapeCode::kStoreToHeap) { heap_data[reg] = input_value; } else if (code == MatchShapeCode::kNoOp) { } else if (code == MatchShapeCode::kAssertEqualToLoad) { - CHECK_EQ(input_value, heap_data[reg]) - << "RuntimeError: " << err_ctx.value_or("") << " match_cast error, " + TVM_FFI_CHECK_EQ(input_value, heap_data[reg], RuntimeError) + << err_ctx.value_or("") << " match_cast error, " << " PrimValue mismatch to a previous populated value."; } else { - LOG(FATAL) << "Unknown match shape code: " << static_cast(code); + TVM_FFI_THROW(InternalError) << "Unknown match shape code: " << static_cast(code); } } @@ -136,30 +137,30 @@ void MatchShape(ffi::PackedArgs args, ffi::Any* rv) { int64_t* heap_data = heap.has_value() ? static_cast((*heap)->data) : nullptr; int64_t size = args[2].cast(); const int64_t kBeginCode = 3; - ICHECK_LE(kBeginCode + size * 2, args.size()); + TVM_FFI_ICHECK_LE(kBeginCode + size * 2, args.size()); // a function that lazily get context for error reporting const int64_t kErrorContextOffset = kBeginCode + size * 2; ffi::Optional err_ctx = args[kErrorContextOffset].cast(); - CHECK_EQ(input_shape.size(), size) - << "RuntimeError: " << err_ctx.value_or("") << " match_cast shape size mismatch."; + TVM_FFI_CHECK_EQ(input_shape.size(), size, RuntimeError) + << err_ctx.value_or("") << " match_cast shape size mismatch."; for (int64_t i = 0; i < size; ++i) { MatchShapeCode code = static_cast(args[kBeginCode + i * 2].cast()); int64_t reg = args[kBeginCode + i * 2 + 1].cast(); if (code == MatchShapeCode::kAssertEqualToImm) { - CHECK_EQ(input_shape[i], reg) - << "RuntimeError: " << err_ctx.value_or("") << " match_cast error, " + TVM_FFI_CHECK_EQ(input_shape[i], reg, RuntimeError) + << err_ctx.value_or("") << " match_cast error, " << " shape[" << i << "]" << " mismatch to specified constant."; } else if (code == MatchShapeCode::kStoreToHeap) { heap_data[reg] = input_shape[i]; } else if (code == MatchShapeCode::kNoOp) { } else { - ICHECK(code == MatchShapeCode::kAssertEqualToLoad); - CHECK_EQ(input_shape[i], heap_data[reg]) - << "RuntimeError: " << err_ctx.value_or("") << " match_cast error, " + TVM_FFI_ICHECK(code == MatchShapeCode::kAssertEqualToLoad); + TVM_FFI_CHECK_EQ(input_shape[i], heap_data[reg], RuntimeError) + << err_ctx.value_or("") << " match_cast error, " << " shape[" << i << "]" << " mismatch to a previous populated value."; } @@ -189,7 +190,7 @@ int64_t MakePrimValue(DLTensor* heap, int shape_code, int64_t reg) { } else if (code == MakeShapeCode::kLoadShape) { return heap_data[reg]; } else { - LOG(FATAL) << "Invalid shape code: " << shape_code; + TVM_FFI_THROW(InternalError) << "Invalid shape code: " << shape_code; } } @@ -220,7 +221,7 @@ void MakeShape(ffi::PackedArgs args, ffi::Any* rv) { if (code == MakeShapeCode::kUseImm) { shape[i] = reg; } else { - ICHECK(code == MakeShapeCode::kLoadShape); + TVM_FFI_ICHECK(code == MakeShapeCode::kLoadShape); shape[i] = heap_data[reg]; } } @@ -254,19 +255,19 @@ void CheckTensorInfo(ffi::PackedArgs args, ffi::Any* rv) { } auto opt_ptr = arg.try_cast(); - CHECK(opt_ptr.has_value()) << "TypeError: " << err_ctx.value_or("") << " expect a Tensor but get " - << arg.GetTypeKey(); + TVM_FFI_CHECK(opt_ptr.has_value(), TypeError) + << err_ctx.value_or("") << " expect a Tensor but get " << arg.GetTypeKey(); DLTensor* ptr = opt_ptr.value(); if (ndim != -1) { - CHECK(ptr->ndim == ndim) << "ValueError: " << err_ctx.value_or("") - << " expect Tensor with ndim " << ndim << " but get " << ptr->ndim; + TVM_FFI_CHECK(ptr->ndim == ndim, ValueError) + << err_ctx.value_or("") << " expect Tensor with ndim " << ndim << " but get " << ptr->ndim; } if (dtype != DataType::Void()) { - CHECK(DataType(ptr->dtype) == dtype) - << "ValueError: " << err_ctx.value_or("") << " expect Tensor with dtype " << dtype - << " but get " << DataType(ptr->dtype); + TVM_FFI_CHECK(DataType(ptr->dtype) == dtype, ValueError) + << err_ctx.value_or("") << " expect Tensor with dtype " << dtype << " but get " + << DataType(ptr->dtype); } } @@ -284,12 +285,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { void CheckShapeInfo(ObjectRef arg, int ndim, ffi::Optional err_ctx) { // a function that lazily get context for error reporting auto* ptr = arg.as(); - CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Shape but get " - << arg->GetTypeKey(); + TVM_FFI_CHECK(ptr != nullptr, TypeError) + << err_ctx.value_or("") << " expect a Shape but get " << arg->GetTypeKey(); if (ndim != -1) { - CHECK(ptr->size == static_cast(ndim)) - << "ValueError: " << err_ctx.value_or("") << " expect Shape with ndim " << ndim - << " but get " << ptr->size; + TVM_FFI_CHECK(ptr->size == static_cast(ndim), ValueError) + << err_ctx.value_or("") << " expect Shape with ndim " << ndim << " but get " << ptr->size; } } @@ -306,8 +306,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { */ void CheckPrimValueInfo(ffi::AnyView arg, DataType dtype, ffi::Optional err_ctx) { if (auto opt_obj = arg.as()) { - LOG(FATAL) << "TypeError: " << err_ctx.value_or("") << ", expected dtype " << dtype - << ", but received ObjectRef of type " << opt_obj.value()->GetTypeKey(); + TVM_FFI_THROW(TypeError) << err_ctx.value_or("") << ", expected dtype " << dtype + << ", but received ObjectRef of type " + << opt_obj.value()->GetTypeKey(); } else if (dtype.is_bool()) { arg.cast(); } else if (dtype.is_int()) { @@ -319,7 +320,7 @@ void CheckPrimValueInfo(ffi::AnyView arg, DataType dtype, ffi::Optional(); } else { - LOG(FATAL) << "TypeError: " << err_ctx.value_or("") << ", unsupported dtype " << dtype; + TVM_FFI_THROW(TypeError) << err_ctx.value_or("") << ", unsupported dtype " << dtype; } } @@ -337,10 +338,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { void CheckTupleInfo(ObjectRef arg, int64_t size, ffi::Optional err_ctx) { // a function that lazily get context for error reporting auto* ptr = arg.as(); - CHECK(ptr != nullptr) << "TypeError: " << err_ctx.value_or("") << " expect a Tuple but get " - << arg->GetTypeKey(); - CHECK(static_cast(ptr->size()) == size) - << "ValueError: " << err_ctx.value_or("") << " expect a Tuple with " << size << " elements, " + TVM_FFI_CHECK(ptr != nullptr, TypeError) + << err_ctx.value_or("") << " expect a Tuple but get " << arg->GetTypeKey(); + TVM_FFI_CHECK(static_cast(ptr->size()) == size, ValueError) + << err_ctx.value_or("") << " expect a Tuple with " << size << " elements, " << " but get a Tuple with " << ptr->size() << " elements."; } @@ -357,8 +358,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { void CheckFuncInfo(ObjectRef arg, ffi::Optional err_ctx) { // a function that lazily get context for error reporting bool is_func = arg.as() || arg.as(); - CHECK(is_func) << "TypeError: " << err_ctx.value_or("") << " expect a Function but get " - << arg->GetTypeKey(); + TVM_FFI_CHECK(is_func, TypeError) + << err_ctx.value_or("") << " expect a Function but get " << arg->GetTypeKey(); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -373,7 +374,7 @@ Storage VMAllocStorage(void* ctx_ptr, ffi::Shape buffer_shape, Index device_inde DLDataType dtype_hint, ffi::String mem_scope) { VirtualMachine* vm = static_cast(ctx_ptr); - ICHECK_LT(device_index, vm->devices.size()) + TVM_FFI_ICHECK_LT(device_index, vm->devices.size()) << "The device index is out of VM physical devices list"; if (device_index == -1) { @@ -382,7 +383,7 @@ Storage VMAllocStorage(void* ctx_ptr, ffi::Shape buffer_shape, Index device_inde } auto* alloc = vm->allocators[device_index]; - ICHECK(alloc) << "Did you forget to init the VirtualMachine with devices?"; + TVM_FFI_ICHECK(alloc) << "Did you forget to init the VirtualMachine with devices?"; auto buffer = alloc->Alloc(vm->devices[device_index], buffer_shape, dtype_hint, mem_scope); @@ -473,7 +474,7 @@ void RegisterPyFunc(const std::string& name, ffi::Function func) { py_func_regis ffi::Function GetPyFunc(const std::string& name) { auto it = py_func_registry.find(name); if (it == py_func_registry.end()) { - LOG(FATAL) << "Python function '" << name << "' not found in registry"; + TVM_FFI_THROW(InternalError) << "Python function '" << name << "' not found in registry"; } return it->second; } @@ -486,12 +487,13 @@ ffi::Function GetPyFunc(const std::string& name) { void CallPyFunc(ffi::PackedArgs args, ffi::Any* rv) { // args[0] should be a tuple containing (func_name, args_tuple) if (args.size() != 1) { - LOG(FATAL) << "vm.builtin.call_py_func expects exactly 1 argument (tuple)"; + TVM_FFI_THROW(InternalError) << "vm.builtin.call_py_func expects exactly 1 argument (tuple)"; } auto tuple_arg = args[0].cast>(); if (tuple_arg.size() != 2) { - LOG(FATAL) << "vm.builtin.call_py_func tuple should contain (func_name, args)"; + TVM_FFI_THROW(InternalError) + << "vm.builtin.call_py_func tuple should contain (func_name, args)"; } // Get function name @@ -556,7 +558,8 @@ bool ReadIfCond(ffi::AnyView cond) { if (arr->device.device_type != kDLCPU) { arr = arr.CopyTo(DLDevice{kDLCPU, 0}); } - ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt || arr->dtype.code == kDLBool); + TVM_FFI_ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt || + arr->dtype.code == kDLBool); int64_t result; switch (arr->dtype.bits) { case 1: { @@ -580,7 +583,7 @@ bool ReadIfCond(ffi::AnyView cond) { break; } default: - LOG(FATAL) << "Unknown scalar int type: " << DLDataTypeToString(arr->dtype); + TVM_FFI_THROW(InternalError) << "Unknown scalar int type: " << DLDataTypeToString(arr->dtype); throw; } return result != 0; @@ -599,15 +602,17 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "vm.builtin.invoke_debug_func", [](ffi::PackedArgs args, ffi::Any* rv) -> void { - ICHECK_GE(args.size(), 3); + TVM_FFI_ICHECK_GE(args.size(), 3); int num_args = args.size() - 3; ObjectRef io_effect = args[0].cast(); - ICHECK(!io_effect.defined()) << "ValueError: IOEffect is expected to be lowered to None."; + TVM_FFI_CHECK(!io_effect.defined(), ValueError) + << "IOEffect is expected to be lowered to None."; ffi::String debug_func_name = args[1].cast(); const auto debug_func = tvm::ffi::Function::GetGlobal(debug_func_name); - CHECK(debug_func.has_value()) << "ValueError: " << debug_func_name << " is not found. " - << "Use the decorator `@tvm.register_global_func(\"" - << debug_func_name << "\")` to register it."; + TVM_FFI_CHECK(debug_func.has_value(), ValueError) + << debug_func_name << " is not found. " + << "Use the decorator `@tvm.register_global_func(\"" << debug_func_name + << "\")` to register it."; ffi::String line_info = args[2].cast(); std::vector call_args(num_args + 1); { @@ -648,8 +653,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { arr = data.CopyTo(DLDevice{kDLCPU, 0}); } - ICHECK_EQ(arr->ndim, 1); - ICHECK_EQ(arr->dtype.code, kDLInt); + TVM_FFI_ICHECK_EQ(arr->ndim, 1); + TVM_FFI_ICHECK_EQ(arr->dtype.code, kDLInt); std::vector out_shape; for (int i = 0; i < arr.Shape()[0]; ++i) { @@ -668,7 +673,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { break; } default: - LOG(FATAL) << "Unknown scalar int type: " << DLDataTypeToString(arr->dtype); + TVM_FFI_THROW(InternalError) + << "Unknown scalar int type: " << DLDataTypeToString(arr->dtype); throw; } out_shape.push_back(result); diff --git a/src/runtime/vm/cuda/cuda_graph_builtin.cc b/src/runtime/vm/cuda/cuda_graph_builtin.cc index 9523fd3f4b30..41f6da7e3cbe 100644 --- a/src/runtime/vm/cuda/cuda_graph_builtin.cc +++ b/src/runtime/vm/cuda/cuda_graph_builtin.cc @@ -253,7 +253,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef() .def_packed("vm.builtin.cuda_graph.run_or_capture", [](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK(args.size() == 5 || args.size() == 4); + TVM_FFI_ICHECK(args.size() == 5 || args.size() == 4); VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); auto extension = vm->GetOrCreateExtension(); auto capture_func = args[1].cast(); @@ -267,7 +267,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { shape_expr); }) .def_packed("vm.builtin.cuda_graph.get_cached_alloc", [](ffi::PackedArgs args, ffi::Any* rv) { - ICHECK_EQ(args.size(), 3); + TVM_FFI_ICHECK_EQ(args.size(), 3); VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); auto extension = vm->GetOrCreateExtension(); auto alloc_func = args[1].cast(); diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 71d3f58ecefc..ada04a1024dd 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -41,9 +41,9 @@ namespace vm { constexpr uint64_t kTVMVMBytecodeMagic = 0xD225DE2F4214151D; constexpr uint64_t kTVMVMBytecodeMagicV2 = 0xD225DE2F4214151E; -#define STREAM_CHECK(val, section) \ - ICHECK(val) << "Invalid VM file format in the " << section << " section." \ - << "\n"; +#define STREAM_CHECK(val, section) \ + TVM_FFI_ICHECK(val) << "Invalid VM file format in the " << section << " section." \ + << "\n"; std::string VMExecutable::Stats() const { std::ostringstream oss; @@ -89,7 +89,7 @@ std::string VMExecutable::Stats() const { oss << dtype; oss << ", "; } else { - LOG(FATAL) << "Unsupported constant pool type " << it.GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unsupported constant pool type " << it.GetTypeKey(); } } if (!constants.empty()) oss.seekp(-2, oss.cur); @@ -107,9 +107,9 @@ std::string VMExecutable::Stats() const { } void VMExecutable::SetInstructionData(Index i, Index j, ExecWord val) { - ICHECK_LT(i, instr_offset.size()); + TVM_FFI_ICHECK_LT(i, instr_offset.size()); Index instr_idx = instr_offset[i]; - ICHECK_LT(instr_idx + j, instr_data.size()); + TVM_FFI_ICHECK_LT(instr_idx + j, instr_data.size()); instr_data[instr_idx + j] = val; } @@ -138,7 +138,7 @@ Instruction VMExecutable::GetInstruction(Index i) const { return Instruction::If(cond, false_offset); } default: - LOG(FATAL) << "should never hit this case: " << static_cast(op); + TVM_FFI_THROW(InternalError) << "should never hit this case: " << static_cast(op); break; } return Instruction(); @@ -292,7 +292,7 @@ void VMExecutable::SaveConstantSection(support::Stream* strm) const { strm->Write(ffi::TypeIndex::kTVMFFIDataType); strm->Write(opt_dtype.value()); } else { - LOG(FATAL) << "Unsupported constant pool type " << it.GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unsupported constant pool type " << it.GetTypeKey(); } } } @@ -379,8 +379,9 @@ void VMExecutable::LoadConstantSection(support::Stream* strm) { cell = value; this->constants.push_back(cell); } else { - LOG(FATAL) << "Constant pool can only contain Tensor and DLDataType, but got " - << ffi::TypeIndexToTypeKey(constant_type) << " when loading the VM constant pool."; + TVM_FFI_THROW(InternalError) + << "Constant pool can only contain Tensor and DLDataType, but got " + << ffi::TypeIndexToTypeKey(constant_type) << " when loading the VM constant pool."; } } } @@ -449,7 +450,7 @@ ffi::String VMExecutable::AsText() const { case Instruction::ArgKind::kFuncIdx: return "f[" + get_func_name(arg.value()) + "]"; default: - LOG(FATAL) << "Wrong instruction kind: " << static_cast(arg.kind()); + TVM_FFI_THROW(InternalError) << "Wrong instruction kind: " << static_cast(arg.kind()); return ""; } }; @@ -466,7 +467,7 @@ ffi::String VMExecutable::AsText() const { os << "@" << gfunc.name << " num_inputs=" << gfunc.num_args << " vm_tir_func;\n\n"; continue; } - ICHECK(gfunc.kind == VMFuncInfo::FuncKind::kVMFunc); + TVM_FFI_ICHECK(gfunc.kind == VMFuncInfo::FuncKind::kVMFunc); os << "@" << gfunc.name << ":\n"; size_t start_instr = gfunc.start_instr; size_t end_instr = gfunc.end_instr; @@ -496,7 +497,8 @@ ffi::String VMExecutable::AsText() const { break; } default: - LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + TVM_FFI_THROW(InternalError) + << "should never hit this case: " << static_cast(instr.op); break; } } @@ -529,7 +531,7 @@ ffi::String VMExecutable::AsPython() const { return "ib.f(" + get_func_name(arg.value()) + ")"; } default: - LOG(FATAL) << "Wrong instruction kind: " << static_cast(arg.kind()); + TVM_FFI_THROW(InternalError) << "Wrong instruction kind: " << static_cast(arg.kind()); return ""; } }; @@ -545,7 +547,7 @@ ffi::String VMExecutable::AsPython() const { if (gfunc.kind == VMFuncInfo::FuncKind::kVMTIRFunc) { continue; } - ICHECK(gfunc.kind == VMFuncInfo::FuncKind::kVMFunc); + TVM_FFI_ICHECK(gfunc.kind == VMFuncInfo::FuncKind::kVMFunc); os << "with ib.function(\"" << gfunc.name << "\", num_inputs=" << gfunc.num_args << "):\n"; size_t start_instr = gfunc.start_instr; @@ -575,7 +577,8 @@ ffi::String VMExecutable::AsPython() const { break; } default: - LOG(FATAL) << "should never hit this case: " << static_cast(instr.op); + TVM_FFI_THROW(InternalError) + << "should never hit this case: " << static_cast(instr.op); break; } } diff --git a/src/runtime/vm/hexagon/builtin.cc b/src/runtime/vm/hexagon/builtin.cc index 72929dd3d8f2..54fd70b2800f 100644 --- a/src/runtime/vm/hexagon/builtin.cc +++ b/src/runtime/vm/hexagon/builtin.cc @@ -45,9 +45,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { void* src = sptr->data; int ret = DMA_RETRY; - CHECK_EQ(GetDataSize(*dptr), GetDataSize(*sptr)); + TVM_FFI_ICHECK_EQ(GetDataSize(*dptr), GetDataSize(*sptr)); auto size = GetDataSize(*dptr); - ICHECK(size > 0); + TVM_FFI_ICHECK(size > 0); if (bypass_cache) qurt_mem_cache_clean(reinterpret_cast(src), size, QURT_MEM_CACHE_INVALIDATE, QURT_MEM_DCACHE); @@ -55,12 +55,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { ret = tvm::runtime::hexagon::HexagonDeviceAPI::Global()->UserDMA()->Copy( queue_id, dst, src, size, bypass_cache); } while (ret == DMA_RETRY); - CHECK(ret == DMA_SUCCESS); + TVM_FFI_ICHECK(ret == DMA_SUCCESS); }) .def("vm.builtin.hexagon.dma_wait", [](ffi::AnyView vm_ptr, int queue_id, int inflight_dma, bool bypass_cache, [[maybe_unused]] Tensor src_arr, [[maybe_unused]] Tensor dst_arr) { - ICHECK(inflight_dma >= 0); + TVM_FFI_ICHECK(inflight_dma >= 0); tvm::runtime::hexagon::HexagonDeviceAPI::Global()->UserDMA()->Wait(queue_id, inflight_dma); if (bypass_cache) { const DLTensor* dptr = dst_arr.operator->(); diff --git a/src/runtime/vm/kv_state.cc b/src/runtime/vm/kv_state.cc index 5d04139a32c8..b82d934f5c67 100644 --- a/src/runtime/vm/kv_state.cc +++ b/src/runtime/vm/kv_state.cc @@ -40,7 +40,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_method("vm.builtin.kv_state_popn", &KVStateObj::PopN) .def_packed("vm.builtin.kv_state_begin_forward", [](ffi::PackedArgs args, ffi::Any* rv) { - CHECK(args.size() == 3 || args.size() == 4) + TVM_FFI_ICHECK(args.size() == 3 || args.size() == 4) << "KVState BeginForward only accepts 3 or 4 arguments"; KVState kv_state = args[0].cast(); ffi::Shape seq_ids = args[1].cast(); diff --git a/src/runtime/vm/lm_support.cc b/src/runtime/vm/lm_support.cc index e4bdb7e86607..1fe2c43ac0f6 100644 --- a/src/runtime/vm/lm_support.cc +++ b/src/runtime/vm/lm_support.cc @@ -83,9 +83,9 @@ class AttentionKVCacheLegacyObj : public Object { * \param shape The cached values. */ Tensor View(const ffi::Shape& shape) { - CHECK_EQ(shape[0], fill_count) << "Requested shape do not match the filled count"; + TVM_FFI_ICHECK_EQ(shape[0], fill_count) << "Requested shape do not match the filled count"; for (int i = 1; i < this->data->ndim; ++i) { - CHECK_EQ(shape[i], data->shape[i]) << "Dimension " << i << " mismatch"; + TVM_FFI_ICHECK_EQ(shape[i], data->shape[i]) << "Dimension " << i << " mismatch"; } return data.CreateView(shape, data->dtype); } @@ -98,15 +98,16 @@ class AttentionKVCacheLegacyObj : public Object { /** pop n entries */ void PopN(size_t n) { - ICHECK_LE(n, fill_count); + TVM_FFI_ICHECK_LE(n, fill_count); this->fill_count -= n; } void Update(Tensor value) { - CHECK(data.DataType() == value.DataType()) << "dtype mismatch"; - CHECK_EQ(value->shape[0], fill_count) << "Requested shape do not match the filled count"; - ICHECK(data.IsContiguous()); - ICHECK(value.IsContiguous()); + TVM_FFI_ICHECK(data.DataType() == value.DataType()) << "dtype mismatch"; + TVM_FFI_ICHECK_EQ(value->shape[0], fill_count) + << "Requested shape do not match the filled count"; + TVM_FFI_ICHECK(data.IsContiguous()); + TVM_FFI_ICHECK(value.IsContiguous()); DLTensor copy_dst = *(data.operator->()); copy_dst.byte_offset = 0; @@ -122,8 +123,9 @@ class AttentionKVCacheLegacyObj : public Object { * \param num_attention_sinks number of sinks to store (https://arxiv.org/abs/2309.17453). */ void WindowOverride(Tensor value, int64_t max_cache_size, int64_t num_attention_sinks = 0) { - CHECK(data.DataType() == value.DataType()) << "dtype mismatch"; - CHECK_LE(value->shape[0], max_cache_size - num_attention_sinks) << "dim 0 of value too large"; + TVM_FFI_ICHECK(data.DataType() == value.DataType()) << "dtype mismatch"; + TVM_FFI_ICHECK_LE(value->shape[0], max_cache_size - num_attention_sinks) + << "dim 0 of value too large"; // reallocate cache if (fill_count + value->shape[0] <= max_cache_size) { int64_t reserved_slots = data->shape[0]; @@ -139,7 +141,7 @@ class AttentionKVCacheLegacyObj : public Object { } } // copy into the current position. - ICHECK(data.IsContiguous()); + TVM_FFI_ICHECK(data.IsContiguous()); int64_t num_elements_to_copy = std::min(value->shape[0], max_cache_size - window_attention_current_pos); @@ -147,7 +149,7 @@ class AttentionKVCacheLegacyObj : public Object { std::vector shape; shape.push_back(num_elements_to_copy); for (int i = 1; i < data->ndim; ++i) { - CHECK_EQ(value->shape[i], data->shape[i]) << "Dimension " << i << " mismatch"; + TVM_FFI_ICHECK_EQ(value->shape[i], data->shape[i]) << "Dimension " << i << " mismatch"; num_elements_p_entry *= data->shape[i]; shape.push_back(data->shape[i]); } @@ -170,8 +172,8 @@ class AttentionKVCacheLegacyObj : public Object { // copy the remainder to the beginning of the cache if (num_elements_to_copy < value->shape[0]) { - ICHECK_EQ(this->fill_count, max_cache_size); - ICHECK_EQ(this->fill_count, this->window_attention_current_pos); + TVM_FFI_ICHECK_EQ(this->fill_count, max_cache_size); + TVM_FFI_ICHECK_EQ(this->fill_count, this->window_attention_current_pos); shape[0] = value->shape[0] - num_elements_to_copy; num_filled_elements = num_elements_to_copy * num_elements_p_entry; @@ -197,7 +199,7 @@ class AttentionKVCacheLegacyObj : public Object { * \param value The value to be appended. */ void Append(Tensor value) { - CHECK(data.DataType() == value.DataType()) << "dtype mismatch"; + TVM_FFI_ICHECK(data.DataType() == value.DataType()) << "dtype mismatch"; // reallocate cache int64_t reserved_slots = data->shape[0]; while (fill_count + value->shape[0] > reserved_slots) { @@ -211,12 +213,12 @@ class AttentionKVCacheLegacyObj : public Object { this->data = new_data; } // copy into the fill count position. - ICHECK_LE(fill_count + value->shape[0], data->shape[0]); - ICHECK(data.IsContiguous()); + TVM_FFI_ICHECK_LE(fill_count + value->shape[0], data->shape[0]); + TVM_FFI_ICHECK(data.IsContiguous()); int64_t num_filled_elements = fill_count; for (int i = 1; i < data->ndim; ++i) { - CHECK_EQ(value->shape[i], data->shape[i]) << "Dimension " << i << " mismatch"; + TVM_FFI_ICHECK_EQ(value->shape[i], data->shape[i]) << "Dimension " << i << " mismatch"; num_filled_elements *= data->shape[i]; } // create a view of copy dest to copy the value into it. @@ -317,8 +319,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "vm.builtin.attention_kv_cache_view", [](ffi::PackedArgs args, ffi::Any* rv) { - CHECK(args.size() == 1 || args.size() == 2) - << "ValueError: `vm.builtin.attention_kv_cache_view` expects 1 or 2 arguments, but got " + TVM_FFI_CHECK(args.size() == 1 || args.size() == 2, ValueError) + << "`vm.builtin.attention_kv_cache_view` expects 1 or 2 arguments, but got " << args.size() << "."; AttentionKVCacheLegacy cache = args[0].cast(); if (args.size() == 2) { @@ -359,17 +361,17 @@ TVM_FFI_STATIC_INIT_BLOCK() { // NOTE this is a built-in highly related to LM so we put it here. int SampleTopPFromLogits(Tensor logits, double temperature, double top_p, double uniform_sample) { - ICHECK(logits.IsContiguous()); - ICHECK(logits.DataType() == DataType::Float(32)); + TVM_FFI_ICHECK(logits.IsContiguous()); + TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32)); if (logits->device.device_type != kDLCPU) { logits = logits.CopyTo(DLDevice{kDLCPU, 0}); } - ICHECK(logits->device.device_type == kDLCPU); + TVM_FFI_ICHECK(logits->device.device_type == kDLCPU); for (int i = 0; i < logits->ndim - 1; ++i) { - ICHECK_EQ(logits->shape[i], 1) << "The leading dimensions of logits must be 1"; + TVM_FFI_ICHECK_EQ(logits->shape[i], 1) << "The leading dimensions of logits must be 1"; } std::vector> data; @@ -415,7 +417,7 @@ int SampleTopPFromLogits(Tensor logits, double temperature, double top_p, double return it->second; } } - ICHECK_LE(uniform_sample, data[0].first); + TVM_FFI_ICHECK_LE(uniform_sample, data[0].first); return data[0].second; } @@ -425,17 +427,17 @@ TVM_FFI_STATIC_INIT_BLOCK() { } int SampleTopPFromProb(Tensor prob, double top_p, double uniform_sample) { - ICHECK(prob.IsContiguous()); - ICHECK(prob.DataType() == DataType::Float(32)); + TVM_FFI_ICHECK(prob.IsContiguous()); + TVM_FFI_ICHECK(prob.DataType() == DataType::Float(32)); if (prob->device.device_type != kDLCPU) { prob = prob.CopyTo(DLDevice{kDLCPU, 0}); } - ICHECK(prob->device.device_type == kDLCPU); + TVM_FFI_ICHECK(prob->device.device_type == kDLCPU); for (int i = 0; i < prob->ndim - 1; ++i) { - ICHECK_EQ(prob->shape[i], 1) << "The leading dimensions of logits must be 1"; + TVM_FFI_ICHECK_EQ(prob->shape[i], 1) << "The leading dimensions of logits must be 1"; } // Key observation: when we are doing top_p sampling @@ -510,9 +512,10 @@ int SampleTopPFromProb(Tensor prob, double top_p, double uniform_sample) { data.reserve(ndata); int64_t sampled_index = sample_top_p_with_filter(0.0f); if (sampled_index < 0 && is_all_nan()) { - LOG(FATAL) << "The output probabilities are all NaNs, can not sample from it"; + TVM_FFI_THROW(InternalError) << "The output probabilities are all NaNs, can not sample from it"; } else if (sampled_index < 0) { - LOG(FATAL) << "Cannot sample from the given probability distribution due to unknown reason"; + TVM_FFI_THROW(InternalError) + << "Cannot sample from the given probability distribution due to unknown reason"; } return sampled_index; } @@ -523,8 +526,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { } Tensor MultinomialFromUniform(Tensor prob, Tensor uniform_sample) { - ICHECK(prob.IsContiguous()); - ICHECK(uniform_sample.IsContiguous()); + TVM_FFI_ICHECK(prob.IsContiguous()); + TVM_FFI_ICHECK(uniform_sample.IsContiguous()); if (prob->device.device_type != kDLCPU) { prob = prob.CopyTo(DLDevice{kDLCPU, 0}); @@ -533,8 +536,8 @@ Tensor MultinomialFromUniform(Tensor prob, Tensor uniform_sample) { uniform_sample = uniform_sample.CopyTo(DLDevice{kDLCPU, 0}); } - ICHECK(prob->device.device_type == kDLCPU); - ICHECK(uniform_sample->device.device_type == kDLCPU); + TVM_FFI_ICHECK(prob->device.device_type == kDLCPU); + TVM_FFI_ICHECK(uniform_sample->device.device_type == kDLCPU); int64_t batch_size = prob->shape[0]; int64_t vocab_size = prob->shape[prob->ndim - 1]; @@ -564,12 +567,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { // This is an inplace operation. void ApplyRepetitionPenalty(Tensor logits, Tensor token_ids, double penalty) { - ICHECK(logits.IsContiguous()); - ICHECK(token_ids.IsContiguous()); - ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - ICHECK(token_ids.DataType() == DataType::Int(32)) << "token ids must be int32!"; - ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!"; - ICHECK(token_ids->device.device_type == kDLCPU) << "token_ids device must be CPU!"; + TVM_FFI_ICHECK(logits.IsContiguous()); + TVM_FFI_ICHECK(token_ids.IsContiguous()); + TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; + TVM_FFI_ICHECK(token_ids.DataType() == DataType::Int(32)) << "token ids must be int32!"; + TVM_FFI_ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!"; + TVM_FFI_ICHECK(token_ids->device.device_type == kDLCPU) << "token_ids device must be CPU!"; float* logits_raw_data = static_cast(logits->data); int* token_ids_data = static_cast(token_ids->data); size_t num_token_ids = token_ids->shape[token_ids->ndim - 1]; @@ -600,15 +603,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { void ApplyPresenceAndFrequencyPenalty(Tensor logits, Tensor token_ids, Tensor token_freqs, double presence_penalty, double frequency_penalty) { // See https://platform.openai.com/docs/guides/text-generation/frequency-and-presence-penalties - ICHECK(logits.IsContiguous()); - ICHECK(token_ids.IsContiguous()); - ICHECK(token_freqs.IsContiguous()); - ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - ICHECK(token_ids.DataType() == DataType::Int(32)) << "token ids must be int32!"; - ICHECK(token_freqs.DataType() == DataType::Int(32)) << "token freqs must be int32!"; - ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!"; - ICHECK(token_ids->device.device_type == kDLCPU) << "token_ids device must be CPU!"; - ICHECK(token_freqs->device.device_type == kDLCPU) << "token_ids device must be CPU!"; + TVM_FFI_ICHECK(logits.IsContiguous()); + TVM_FFI_ICHECK(token_ids.IsContiguous()); + TVM_FFI_ICHECK(token_freqs.IsContiguous()); + TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; + TVM_FFI_ICHECK(token_ids.DataType() == DataType::Int(32)) << "token ids must be int32!"; + TVM_FFI_ICHECK(token_freqs.DataType() == DataType::Int(32)) << "token freqs must be int32!"; + TVM_FFI_ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!"; + TVM_FFI_ICHECK(token_ids->device.device_type == kDLCPU) << "token_ids device must be CPU!"; + TVM_FFI_ICHECK(token_freqs->device.device_type == kDLCPU) << "token_ids device must be CPU!"; float* logits_raw_data = static_cast(logits->data); int* token_ids_data = static_cast(token_ids->data); @@ -629,9 +632,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { // This is an inplace operation. void ApplySoftmaxWithTemperature(Tensor logits, double temperature) { - ICHECK(logits.IsContiguous()); - ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; - ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!"; + TVM_FFI_ICHECK(logits.IsContiguous()); + TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; + TVM_FFI_ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!"; int vocab_size = logits->shape[logits->ndim - 1]; float* logits_raw_data = static_cast(logits->data); float inv_temp = 1.0f / temperature; diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index 4fb3cd69d60f..36f7697237e2 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -340,22 +340,22 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { device_(device) { // Note: For MLA, sliding window and disaggregation are disabled for now. if (std::find(attn_kinds_.begin(), attn_kinds_.end(), AttnKind::kMLA) != attn_kinds_.end()) { - CHECK(!support_sliding_window_) << "Sliding window not supported yet for MLA"; - CHECK(!enable_kv_transfer) << "KV transfer not supported yet for MLA"; + TVM_FFI_ICHECK(!support_sliding_window_) << "Sliding window not supported yet for MLA"; + TVM_FFI_ICHECK(!enable_kv_transfer) << "KV transfer not supported yet for MLA"; } pages_.reserve(num_layers); if (enable_kv_transfer) { // For now, KV transfer only supports MHA. for (AttnKind attn_kind : attn_kinds_) { - CHECK(attn_kind == AttnKind::kMHA); + TVM_FFI_ICHECK(attn_kind == AttnKind::kMHA); } const auto f_nvshmem_init = tvm::ffi::Function::GetGlobal("runtime.disco.nvshmem.init_nvshmem"); - CHECK(f_nvshmem_init.has_value()) + TVM_FFI_ICHECK(f_nvshmem_init.has_value()) << "NVSHMEM is not enabled. Please make sure NVSHMEM is enabled when compiling TVM."; const auto f_nvshmem_empty = tvm::ffi::Function::GetGlobal("runtime.disco.nvshmem.empty"); - ICHECK(f_nvshmem_empty.has_value()); + TVM_FFI_ICHECK(f_nvshmem_empty.has_value()); nvshmem_pages_ = (*f_nvshmem_empty)( ffi::Shape({num_layers, num_total_pages, 2, num_kv_heads, page_size, qk_head_dim}), @@ -371,8 +371,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { const auto f_transfer_kv_ptr = tvm::ffi::Function::GetGlobal("nvshmem.KVTransfer"); const auto f_transfer_kv_page_to_page_ptr = tvm::ffi::Function::GetGlobal("nvshmem.KVTransferPageToPage"); - ICHECK(f_transfer_kv_ptr.has_value()); - ICHECK(f_transfer_kv_page_to_page_ptr.has_value()); + TVM_FFI_ICHECK(f_transfer_kv_ptr.has_value()); + TVM_FFI_ICHECK(f_transfer_kv_page_to_page_ptr.has_value()); f_transfer_kv_ = *f_transfer_kv_ptr; f_transfer_kv_page_to_page_ = *f_transfer_kv_page_to_page_ptr; } else { @@ -513,7 +513,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Right now only the "normal" RoPE mode supports the RoPE extention factors. if (rope_ext_factors_.defined()) { - CHECK(rope_mode_ == RoPEMode::kNormal) + TVM_FFI_ICHECK(rope_mode_ == RoPEMode::kNormal) << "The RoPE mode must be normal to support RoPE extension factors."; } } @@ -543,7 +543,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /************** Sequence Management **************/ void AddSequence(int64_t seq_id) final { - CHECK(seq_map_.find(seq_id) == seq_map_.end()) + TVM_FFI_ICHECK(seq_map_.find(seq_id) == seq_map_.end()) << "The sequence \"" << seq_id << "\" is already in the KV cache."; int32_t block_idx = GetFreeBlock(); seq_map_.insert({seq_id, Sequence(&global_block_pool_, block_idx)}); @@ -552,10 +552,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void RemoveSequence(int64_t seq_id) final { auto it = seq_map_.find(seq_id); - CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; + TVM_FFI_ICHECK(it != seq_map_.end()) + << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; int32_t block_idx = it->second.last_block_idx; // The block should have at least one reference, which comes from the sequence. - ICHECK_GE(global_block_pool_[block_idx].external_ref_cnt, 1); + TVM_FFI_ICHECK_GE(global_block_pool_[block_idx].external_ref_cnt, 1); while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 1) { // - Free pages in the last block. for (int32_t page_id : global_block_pool_[block_idx].page_ids) { @@ -566,7 +567,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } // - Decrease the external reference of the parent block. if (block_idx != -1) { - ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 1); + TVM_FFI_ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 1); --global_block_pool_[block_idx].external_ref_cnt; } seq_map_.erase(it); @@ -575,15 +576,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) final { auto parent_it = seq_map_.find(parent_seq_id); - CHECK(parent_it != seq_map_.end()) + TVM_FFI_ICHECK(parent_it != seq_map_.end()) << "The parent sequence \"" << parent_seq_id << "\" cannot be found in KV cache."; - CHECK(seq_map_.find(child_seq_id) == seq_map_.end()) + TVM_FFI_ICHECK(seq_map_.find(child_seq_id) == seq_map_.end()) << "The child sequence \"" << child_seq_id << "\" is already in the KV cache."; - CHECK_GE(fork_pos, -1) + TVM_FFI_ICHECK_GE(fork_pos, -1) << "The forked position should be non-negative, or -1 for last position as default."; - CHECK_LE(fork_pos, parent_it->second.seq_length) + TVM_FFI_ICHECK_LE(fork_pos, parent_it->second.seq_length) << "The forked position should not exceed the total length of parent sequence."; - CHECK(parent_it->second.accepted_indices_committed) + TVM_FFI_ICHECK(parent_it->second.accepted_indices_committed) << "The parent sequence's token tree computed in the last round of forward has not been " "committed with accepted nodes."; @@ -597,7 +598,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { const Sequence& seq = parent_it->second; int32_t sink_size = seq.seq_length - global_block_pool_[seq.last_block_idx].seq_length + seq.last_block_attn_sink_size; - CHECK_LE(fork_pos, sink_size) + TVM_FFI_ICHECK_LE(fork_pos, sink_size) << "The parent sequence \"" << parent_seq_id << "\" is enabled with sliding window and thus only can be forked within sink size = " << sink_size << ". But the forked position = " << fork_pos << "."; @@ -620,8 +621,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { int64_t in_block_offset = fork_pos; for (int32_t forked_block_idx : trace) { if (forked_block_idx != trace.back()) { - CHECK_GT(global_block_pool_[forked_block_idx].seq_length, 0); - CHECK_EQ(global_block_pool_[forked_block_idx].seq_length % page_size_, 0); + TVM_FFI_ICHECK_GT(global_block_pool_[forked_block_idx].seq_length, 0); + TVM_FFI_ICHECK_EQ(global_block_pool_[forked_block_idx].seq_length % page_size_, 0); if (global_block_pool_[forked_block_idx].seq_length <= in_block_offset) { in_block_offset -= global_block_pool_[forked_block_idx].seq_length; continue; @@ -667,7 +668,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // last block if (parent_it->second.sliding_window_size != -1 && forked_block_idx == parent_it->second.last_block_idx) { - CHECK_LE(moved_offset, parent_it->second.last_block_attn_sink_size); + TVM_FFI_ICHECK_LE(moved_offset, parent_it->second.last_block_attn_sink_size); parent_it->second.last_block_attn_sink_size -= moved_offset; } } @@ -705,7 +706,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void CompactKVCopy() { int total_copy_length = commit_copy_length_indptr_host_.back(); - ICHECK_GE(total_copy_length, 0); + TVM_FFI_ICHECK_GE(total_copy_length, 0); if (total_copy_length == 0) { return; } @@ -724,7 +725,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Set the copy stream for copy. DeviceAPI::Get(device_)->SetStream(device_, copy_stream_); } - ICHECK(f_compact_copy_.defined()) << "Function \"f_compact_copy\" is not defined."; + TVM_FFI_ICHECK(f_compact_copy_.defined()) << "Function \"f_compact_copy\" is not defined."; for (int layer = 0; layer < num_layers_; ++layer) { f_compact_copy_(pages_[layer], commit_copy_length_indptr_view, commit_copy_src_dst_pos_in_page_table_view, cur_batch_size_); @@ -742,24 +743,26 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size, int32_t attn_sink_size) final { // If per layer sliding window exists, enable sliding window for sequence - CHECK(support_sliding_window_ || support_layer_sliding_window_) + TVM_FFI_ICHECK(support_sliding_window_ || support_layer_sliding_window_) << "The KV cache does not support sliding window."; auto it = seq_map_.find(seq_id); - CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; - CHECK_GE(attn_sink_size, 0) + TVM_FFI_ICHECK(it != seq_map_.end()) + << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; + TVM_FFI_ICHECK_GE(attn_sink_size, 0) << "The specified attention sink size is expected to be non negative"; - CHECK_GT(sliding_window_size, 0) << "The specified sliding window size should be positive."; - CHECK_LT(attn_sink_size, sliding_window_size) + TVM_FFI_ICHECK_GT(sliding_window_size, 0) + << "The specified sliding window size should be positive."; + TVM_FFI_ICHECK_LT(attn_sink_size, sliding_window_size) << "The attn sink size should be less than the sliding window size."; // Set the sliding window flag of the sequence. - CHECK_EQ(it->second.sliding_window_size, -1) + TVM_FFI_ICHECK_EQ(it->second.sliding_window_size, -1) << "A sequence cannot be enabled twice for sliding window."; // Compute the total length of the prefix blocks of this sequence. const Block& last_block = global_block_pool_[it->second.last_block_idx]; int32_t prefix_length = it->second.seq_length - last_block.seq_length; - ICHECK_GE(prefix_length, 0); + TVM_FFI_ICHECK_GE(prefix_length, 0); // Since the prefix blocks cannot sliding, they are natural // attention sinks here. When the prefix length is already // larger than the specified attn sink size, we do not want to @@ -770,10 +773,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void PopN(int64_t seq_id, int32_t n) final { auto it = seq_map_.find(seq_id); - CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; + TVM_FFI_ICHECK(it != seq_map_.end()) + << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; - CHECK_GE(n, 0) << "The length of popping " << n << " cannot be negative."; - CHECK_LE(n, it->second.seq_length) + TVM_FFI_ICHECK_GE(n, 0) << "The length of popping " << n << " cannot be negative."; + TVM_FFI_ICHECK_LE(n, it->second.seq_length) << "The sequence only has length " << it->second.seq_length << ", while the length of pop is " << n << " which exceeds the whole sequence length."; if (n == 0) { @@ -782,7 +786,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { int32_t block_idx = it->second.last_block_idx; // The block should have at least one reference, which comes from the sequence. - ICHECK_GE(global_block_pool_[block_idx].external_ref_cnt, 1); + TVM_FFI_ICHECK_GE(global_block_pool_[block_idx].external_ref_cnt, 1); while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 1) { if (n > global_block_pool_[block_idx].seq_length) { n -= global_block_pool_[block_idx].seq_length; @@ -815,11 +819,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // We use a temporary sequence id for fork. // This temporary seq id will immediately end its effect outside this function. int64_t temp_seq_id = -1 - seq_id; - CHECK(seq_map_.find(temp_seq_id) == seq_map_.end()); + TVM_FFI_ICHECK(seq_map_.find(temp_seq_id) == seq_map_.end()); ForkSequence(seq_id, temp_seq_id, it->second.seq_length - n); - CHECK(seq_map_.find(temp_seq_id) != seq_map_.end()); + TVM_FFI_ICHECK(seq_map_.find(temp_seq_id) != seq_map_.end()); RemoveSequence(seq_id); - CHECK(seq_map_.find(seq_id) == seq_map_.end()); + TVM_FFI_ICHECK(seq_map_.find(seq_id) == seq_map_.end()); auto it = seq_map_.find(temp_seq_id); seq_map_.insert({seq_id, it->second}); seq_map_.erase(temp_seq_id); @@ -852,10 +856,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { const ffi::Optional& opt_token_tree_parent_ptr) final { // Note: MLA does not supported tree attention for now. if (attn_kinds_[0] == AttnKind::kMLA) { - CHECK(!opt_token_tree_parent_ptr.defined()) << "Tree attention is not supported yet for MLA"; + TVM_FFI_ICHECK(!opt_token_tree_parent_ptr.defined()) + << "Tree attention is not supported yet for MLA"; } - CHECK_EQ(seq_ids.size(), append_lengths.size()) + TVM_FFI_ICHECK_EQ(seq_ids.size(), append_lengths.size()) << "The seq_ids size (" << seq_ids.size() << ") and append_lengths size (" << append_lengths.size() << ") mismatch."; cur_batch_size_ = seq_ids.size(); @@ -871,8 +876,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { k_ragged_rope_pos_offset_host_.clear(); for (int i = 0; i < cur_batch_size_; ++i) { auto it = seq_map_.find(seq_ids[i]); - CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i] - << "\" cannot be found in KV cache."; + TVM_FFI_ICHECK(it != seq_map_.end()) + << "The sequence \"" << seq_ids[i] << "\" cannot be found in KV cache."; sequences.push_back(&it->second); last_block_length_before_append.push_back( global_block_pool_[it->second.last_block_idx].seq_length); @@ -892,7 +897,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { GetBlockIdsOnDepth(sequences, global_block_pool_, cur_batch_size_); num_depths_ = std::min(static_cast(block_ids_on_depths.size()), kPagedKVCacheMaxBlockDepth); - ICHECK_LE(num_depths_, kPagedKVCacheMaxBlockDepth); + TVM_FFI_ICHECK_LE(num_depths_, kPagedKVCacheMaxBlockDepth); std::vector>> chunked_block_ids_arr; chunked_block_ids_arr.reserve(num_depths_); @@ -910,7 +915,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (num_depths_ == kPagedKVCacheMaxBlockDepth) { // Since we force the blocks at maximum depth not to coalesce, the output blocks at maximum // depth must have the same size as current batch. - CHECK_EQ(chunked_block_ids_arr[num_depths_ - 1].size(), cur_batch_size_); + TVM_FFI_ICHECK_EQ(chunked_block_ids_arr[num_depths_ - 1].size(), cur_batch_size_); } append_before_attn_ = !support_sliding_window_ && use_decode_kernel_.back(); @@ -931,8 +936,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // - Check token tree validity and process the token tree. if (opt_token_tree_parent_ptr.defined()) { - CHECK(!support_sliding_window_) << "Tree attention does not support sliding window."; - CHECK(rope_mode_ != RoPEMode::kInline) << "Tree attention does not support inline RoPE mode."; + TVM_FFI_ICHECK(!support_sliding_window_) << "Tree attention does not support sliding window."; + TVM_FFI_ICHECK(rope_mode_ != RoPEMode::kInline) + << "Tree attention does not support inline RoPE mode."; ConstructTokenTreeMask(sequences, opt_token_tree_parent_ptr.value(), block_ids_on_depths, trailing_blocks); } else { @@ -940,7 +946,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // is required to have all past accepted tokens committed. for (int i = 0; i < cur_batch_size_; ++i) { Sequence* sequence = sequences[i]; - CHECK(sequence->accepted_indices_committed) + TVM_FFI_ICHECK(sequence->accepted_indices_committed) << "The input batch does not form a tree, in which case the sequences in the input " "batch are expected to have their accepted tokens token tree nodes committed. " "Please invoke CommitAcceptedTokenTreeNodes for sequence " @@ -1129,7 +1135,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } else { int64_t offset_in_tree = static_cast(sequences[i]->token_tree_parent_ptr.size()) - append_length; - ICHECK_GE(offset_in_tree, 0); + TVM_FFI_ICHECK_GE(offset_in_tree, 0); q_rope_position_map_host_.push_back( k_ragged_rope_pos_offset_host_[i] + sequences[i]->token_tree_node_depths[offset_in_tree + pos]); @@ -1198,7 +1204,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Compression format: [n, begin_1, length_1, begin_2, length_2, ..., begin_n, length_n] // The compressed format will be decompressed to: // [begin_1, begin_1+1, ..., begin_1+length_1-1, ..., begin_n, ..., begin_n+length_n-1] - CHECK_EQ(append_position_map_host_.size(), append_length); + TVM_FFI_ICHECK_EQ(append_position_map_host_.size(), append_length); std::vector compressed_append_pos_map{/*num_segments=*/1, append_position_map_host_[0]}; for (int i = 1; i < append_length; ++i) { @@ -1215,15 +1221,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { compressed_append_pos_map.push_back(append_position_map_host_.back() - compressed_append_pos_map.back() + 1); // The compressed array size should be "num_segments * 2 + 1". - CHECK_EQ(compressed_append_pos_map.size(), compressed_append_pos_map[0] * 2 + 1); + TVM_FFI_ICHECK_EQ(compressed_append_pos_map.size(), compressed_append_pos_map[0] * 2 + 1); return ffi::Shape{compressed_append_pos_map}; } void DisaggMarkSend(int64_t seq_id, int64_t begin, const ffi::Shape& compressed_remote_position_map, int32_t recver_pe_offset) { - ICHECK(f_transfer_kv_.defined()); + TVM_FFI_ICHECK(f_transfer_kv_.defined()); auto it = seq_map_.find(seq_id); - CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; + TVM_FFI_ICHECK(it != seq_map_.end()) + << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; Sequence* sequence = &it->second; sequence->kv_transfer_metadata.start = begin; int nsegments = compressed_remote_position_map[0]; @@ -1242,8 +1249,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { return; } // Need to send existing KV. - CHECK_GT(static_cast(sequence->kv_transfer_metadata.remote_position_map.size()), - sequence->seq_length - begin) + TVM_FFI_ICHECK_GT(static_cast(sequence->kv_transfer_metadata.remote_position_map.size()), + sequence->seq_length - begin) << "Need at least one token to prefill"; std::vector trace = sequence->GetBlockTrace(global_block_pool_); sequence->kv_transfer_metadata.local_position_map.reserve(sequence->seq_length - begin); @@ -1275,38 +1282,38 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { Tensor o_data, double sm_scale) final { // Part 1. Shape and dtype check. int64_t local_layer_id = layer_id - layer_id_begin_offset_; - CHECK_GE(local_layer_id, 0); - CHECK_LT(local_layer_id, num_layers_); + TVM_FFI_ICHECK_GE(local_layer_id, 0); + TVM_FFI_ICHECK_LT(local_layer_id, num_layers_); Tensor pages = pages_[local_layer_id]; - CHECK(qkv_data.DataType() == pages.DataType()); - CHECK(o_data.DataType() == pages.DataType()); - CHECK(attn_kinds_[layer_id] == AttnKind::kMHA || - attn_kinds_[layer_id] == AttnKind::kMHASliding); + TVM_FFI_ICHECK(qkv_data.DataType() == pages.DataType()); + TVM_FFI_ICHECK(o_data.DataType() == pages.DataType()); + TVM_FFI_ICHECK(attn_kinds_[layer_id] == AttnKind::kMHA || + attn_kinds_[layer_id] == AttnKind::kMHASliding); // qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads, qk_head_dim) // o_data: (num_total_length, num_qo_heads, qk_head_dim) - CHECK_EQ(qkv_data->ndim, 3); - CHECK_EQ(o_data->ndim, 3); + TVM_FFI_ICHECK_EQ(qkv_data->ndim, 3); + TVM_FFI_ICHECK_EQ(o_data->ndim, 3); for (int dim = 0; dim < 3; ++dim) { if (dim == 1) { - CHECK_EQ(qkv_data->shape[1], num_qo_heads_ + 2 * num_kv_heads_); - CHECK_EQ(o_data->shape[1], num_qo_heads_); + TVM_FFI_ICHECK_EQ(qkv_data->shape[1], num_qo_heads_ + 2 * num_kv_heads_); + TVM_FFI_ICHECK_EQ(o_data->shape[1], num_qo_heads_); } else { - CHECK_EQ(o_data->shape[dim], qkv_data->shape[dim]); + TVM_FFI_ICHECK_EQ(o_data->shape[dim], qkv_data->shape[dim]); } } - CHECK_EQ(qkv_data->shape[2], qk_head_dim_); + TVM_FFI_ICHECK_EQ(qkv_data->shape[2], qk_head_dim_); int64_t total_seq_length = 0; for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) { total_seq_length += cur_append_lengths_[seq_id]; } - CHECK_LE(total_seq_length, qkv_data->shape[0]); + TVM_FFI_ICHECK_LE(total_seq_length, qkv_data->shape[0]); // Sync the copy stream and the compute stream. ComputeStreamWaitForCopyStream(); // The auxiliary data structure on device must have been synchronized. - ICHECK(!dirty_aux_data_device_); + TVM_FFI_ICHECK(!dirty_aux_data_device_); Tensor q_data = temp_attn_q_device_.CreateView({total_seq_length, num_qo_heads_, qk_head_dim_}, qkv_data->dtype); @@ -1337,7 +1344,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } // Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set. - CHECK(f_transpose_append_mha_.defined()); + TVM_FFI_ICHECK(f_transpose_append_mha_.defined()); if (append_before_attn_) { f_transpose_append_mha_.value()(pages_[local_layer_id], k_data, v_data, append_position_map_view_); @@ -1376,13 +1383,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { Tensor lse_data, double sm_scale) final { // Shape and dtype check. int64_t local_layer_id = layer_id - layer_id_begin_offset_; - CHECK_GE(local_layer_id, 0); - CHECK_LT(local_layer_id, num_layers_); + TVM_FFI_ICHECK_GE(local_layer_id, 0); + TVM_FFI_ICHECK_LT(local_layer_id, num_layers_); Tensor pages = pages_[local_layer_id]; - CHECK(q_data.DataType() == pages.DataType()); - CHECK(k_data.DataType() == pages.DataType()); - CHECK(v_data.DataType() == pages.DataType()); - CHECK(o_data.DataType() == pages.DataType()); + TVM_FFI_ICHECK(q_data.DataType() == pages.DataType()); + TVM_FFI_ICHECK(k_data.DataType() == pages.DataType()); + TVM_FFI_ICHECK(v_data.DataType() == pages.DataType()); + TVM_FFI_ICHECK(o_data.DataType() == pages.DataType()); AttnKind attn_kind = attn_kinds_[layer_id]; // q_data: (num_total_length, num_qo_heads, qk_head_dim) @@ -1394,19 +1401,19 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) { total_seq_length += cur_append_lengths_[seq_id]; } - CHECK_EQ(q_data->ndim, 3); - CHECK_EQ(k_data->ndim, 3); - CHECK_EQ(v_data->ndim, 3); - CHECK_EQ(o_data->ndim, 3); - CHECK_EQ(q_data->shape[0], total_seq_length); - CHECK_EQ(k_data->shape[0], total_seq_length); - CHECK_EQ(v_data->shape[0], total_seq_length); - CHECK_EQ(o_data->shape[0], total_seq_length); + TVM_FFI_ICHECK_EQ(q_data->ndim, 3); + TVM_FFI_ICHECK_EQ(k_data->ndim, 3); + TVM_FFI_ICHECK_EQ(v_data->ndim, 3); + TVM_FFI_ICHECK_EQ(o_data->ndim, 3); + TVM_FFI_ICHECK_EQ(q_data->shape[0], total_seq_length); + TVM_FFI_ICHECK_EQ(k_data->shape[0], total_seq_length); + TVM_FFI_ICHECK_EQ(v_data->shape[0], total_seq_length); + TVM_FFI_ICHECK_EQ(o_data->shape[0], total_seq_length); // Sync the copy stream and the compute stream. ComputeStreamWaitForCopyStream(); // The auxiliary data structure on device must have been synchronized. - ICHECK(!dirty_aux_data_device_); + TVM_FFI_ICHECK(!dirty_aux_data_device_); if (attn_kind == AttnKind::kMHA) { MHASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale); @@ -1419,11 +1426,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { double sm_scale) final { // Shape and dtype check. int64_t local_layer_id = layer_id - layer_id_begin_offset_; - CHECK_GE(local_layer_id, 0); - CHECK_LT(local_layer_id, num_layers_); + TVM_FFI_ICHECK_GE(local_layer_id, 0); + TVM_FFI_ICHECK_LT(local_layer_id, num_layers_); Tensor pages = pages_[local_layer_id]; - CHECK(q_data.DataType() == pages.DataType()); - CHECK(o_data.DataType() == pages.DataType()); + TVM_FFI_ICHECK(q_data.DataType() == pages.DataType()); + TVM_FFI_ICHECK(o_data.DataType() == pages.DataType()); AttnKind attn_kind = attn_kinds_[layer_id]; // q_data: (num_total_length, num_qo_heads, qk_head_dim) @@ -1433,19 +1440,19 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) { total_seq_length += cur_append_lengths_[seq_id]; } - CHECK_EQ(q_data->ndim, 3); - CHECK_EQ(o_data->ndim, 3); - CHECK_EQ(q_data->shape[0], total_seq_length); - CHECK_EQ(o_data->shape[0], total_seq_length); - CHECK_EQ(q_data->shape[1], num_qo_heads_); - CHECK_EQ(o_data->shape[1], num_qo_heads_); - CHECK_EQ(q_data->shape[2], qk_head_dim_); - CHECK_EQ(o_data->shape[2], v_head_dim_); + TVM_FFI_ICHECK_EQ(q_data->ndim, 3); + TVM_FFI_ICHECK_EQ(o_data->ndim, 3); + TVM_FFI_ICHECK_EQ(q_data->shape[0], total_seq_length); + TVM_FFI_ICHECK_EQ(o_data->shape[0], total_seq_length); + TVM_FFI_ICHECK_EQ(q_data->shape[1], num_qo_heads_); + TVM_FFI_ICHECK_EQ(o_data->shape[1], num_qo_heads_); + TVM_FFI_ICHECK_EQ(q_data->shape[2], qk_head_dim_); + TVM_FFI_ICHECK_EQ(o_data->shape[2], v_head_dim_); // Sync the copy stream and the compute stream. ComputeStreamWaitForCopyStream(); // The auxiliary data structure on device must have been synchronized. - ICHECK(!dirty_aux_data_device_); + TVM_FFI_ICHECK(!dirty_aux_data_device_); if (attn_kind == AttnKind::kMHA) { MHACrossAttnInternal(local_layer_id, q_data, o_data, lse_data, sm_scale, @@ -1458,32 +1465,33 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void AppendMLAKV(int64_t layer_id, Tensor kv_data) final { // Shape and dtype check. int64_t local_layer_id = layer_id - layer_id_begin_offset_; - CHECK_GE(local_layer_id, 0); - CHECK_LT(local_layer_id, num_layers_); + TVM_FFI_ICHECK_GE(local_layer_id, 0); + TVM_FFI_ICHECK_LT(local_layer_id, num_layers_); Tensor pages = pages_[local_layer_id]; - CHECK(kv_data.DataType() == pages.DataType()); - CHECK(attn_kinds_[layer_id] == AttnKind::kMLA); + TVM_FFI_ICHECK(kv_data.DataType() == pages.DataType()); + TVM_FFI_ICHECK(attn_kinds_[layer_id] == AttnKind::kMLA); // kv_data: (num_total_length, qk_head_dim) - CHECK_EQ(kv_data->ndim, 2); + TVM_FFI_ICHECK_EQ(kv_data->ndim, 2); int64_t total_seq_length = 0; for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) { total_seq_length += cur_append_lengths_[seq_id]; } - CHECK_LE(kv_data->shape[0], total_seq_length); - CHECK_EQ(kv_data->shape[1], qk_head_dim_); + TVM_FFI_ICHECK_LE(kv_data->shape[0], total_seq_length); + TVM_FFI_ICHECK_EQ(kv_data->shape[1], qk_head_dim_); // Sync the copy stream and the compute stream. ComputeStreamWaitForCopyStream(); // The auxiliary data structure on device must have been synchronized. - ICHECK(!dirty_aux_data_device_); + TVM_FFI_ICHECK(!dirty_aux_data_device_); - CHECK(f_transpose_append_mla_.defined()); + TVM_FFI_ICHECK(f_transpose_append_mla_.defined()); f_transpose_append_mla_.value()(pages_[local_layer_id], kv_data, append_position_map_view_); } ffi::Array MergeAttnOutputInplace(Tensor o_self_attn, Tensor lse_self_attn, Tensor o_cross_attn, Tensor lse_cross_attn) final { - CHECK_GE(f_merge_inplace_.size(), 2) << "The general attention merge function is not defined."; + TVM_FFI_ICHECK_GE(f_merge_inplace_.size(), 2) + << "The general attention merge function is not defined."; f_merge_inplace_[1](o_self_attn, lse_self_attn, o_cross_attn, lse_cross_attn); return {o_self_attn, lse_self_attn}; } @@ -1495,7 +1503,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void CommitAcceptedTokenTreeNodes(const ffi::Shape& seq_ids, const ffi::Shape& leaf_indices) final { - CHECK_EQ(seq_ids.size(), leaf_indices.size()) + TVM_FFI_ICHECK_EQ(seq_ids.size(), leaf_indices.size()) << "The given seq_ids and leaf_indices have different size."; int num_seq_to_commit = seq_ids.size(); @@ -1504,15 +1512,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { bool is_chain = true; for (int i = 0; i < num_seq_to_commit; ++i) { auto it = seq_map_.find(seq_ids[i]); - CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i] - << "\" cannot be found in KV cache."; + TVM_FFI_ICHECK(it != seq_map_.end()) + << "The sequence \"" << seq_ids[i] << "\" cannot be found in KV cache."; sequences.push_back(&it->second); is_chain = it->second.is_chain; - CHECK(leaf_indices[i] == -1 || !it->second.accepted_indices_committed) + TVM_FFI_ICHECK(leaf_indices[i] == -1 || !it->second.accepted_indices_committed) << "The accepted nodes of sequence " << seq_ids[i] << " are already committed."; - CHECK_GE(leaf_indices[i], -1) + TVM_FFI_ICHECK_GE(leaf_indices[i], -1) << "Invalid tree index " << leaf_indices[i] << " which is less than -1"; - CHECK_LT(leaf_indices[i], static_cast(it->second.token_tree_parent_ptr.size())) + TVM_FFI_ICHECK_LT(leaf_indices[i], + static_cast(it->second.token_tree_parent_ptr.size())) << "Invalid tree index " << leaf_indices[i] << " which is larger than or equals to the append length " << it->second.token_tree_parent_ptr.size() << " of the sequence"; @@ -1539,7 +1548,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { path_on_tree.push_back(node); node = sequences[i]->token_tree_parent_ptr[node]; } - ICHECK_EQ(path_on_tree.size(), sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1); + TVM_FFI_ICHECK_EQ(path_on_tree.size(), + sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1); // Get the destination array (range [0, path_length - 1)) of KV cache copy. std::vector copy_dst_pos_in_seq; copy_dst_pos_in_seq.resize(path_on_tree.size()); @@ -1590,20 +1600,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Sync the copy stream and the compute stream. ComputeStreamWaitForCopyStream(); // The auxiliary data structure on device must have been synchronized. - ICHECK(!dirty_aux_data_device_); + TVM_FFI_ICHECK(!dirty_aux_data_device_); return q_rope_position_map_view_; }; void DebugGetKV(int64_t seq_id, int64_t start_pos, int64_t end_pos, Tensor k_data, Tensor v_data) final { - CHECK(f_debug_get_kv_.defined()) + TVM_FFI_ICHECK(f_debug_get_kv_.defined()) << "PageAttentionKVCache requires the `f_debug_get_kv` to be explicitly passed in when " "initialization. Please construct the KV cache with `f_debug_get_kv`."; const Sequence& seq = seq_map_.at(seq_id); - CHECK_GE(start_pos, 0) << "DebugGetKV does not accept negative start_pos " << start_pos; - CHECK_LE(end_pos, seq.seq_length) << "DebugGetKV does not accept out-of-range end_pos"; - CHECK_LT(start_pos, end_pos) << "DebugGetKV does not accept \"start_pos >= end_pos\""; + TVM_FFI_ICHECK_GE(start_pos, 0) + << "DebugGetKV does not accept negative start_pos " << start_pos; + TVM_FFI_ICHECK_LE(end_pos, seq.seq_length) << "DebugGetKV does not accept out-of-range end_pos"; + TVM_FFI_ICHECK_LT(start_pos, end_pos) << "DebugGetKV does not accept \"start_pos >= end_pos\""; // k/v_data: (num_layers, seq_length, num_kv_heads, qk_head_dim) static constexpr const char* error_msg = @@ -1611,14 +1622,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { "qk_head_dim)."; std::vector vec_kv_data = {&k_data, &v_data}; for (const Tensor* data_ptr : vec_kv_data) { - CHECK_EQ((*data_ptr)->ndim, 4) << error_msg; - CHECK_EQ((*data_ptr)->shape[0], num_layers_) + TVM_FFI_ICHECK_EQ((*data_ptr)->ndim, 4) << error_msg; + TVM_FFI_ICHECK_EQ((*data_ptr)->shape[0], num_layers_) << error_msg << " The number of layers mismatches."; - CHECK_EQ((*data_ptr)->shape[1], end_pos - start_pos) + TVM_FFI_ICHECK_EQ((*data_ptr)->shape[1], end_pos - start_pos) << error_msg << " The sequence length mismatches."; - CHECK_EQ((*data_ptr)->shape[2], num_kv_heads_) + TVM_FFI_ICHECK_EQ((*data_ptr)->shape[2], num_kv_heads_) << error_msg << " The number of heads mismatches."; - CHECK_EQ((*data_ptr)->shape[3], qk_head_dim_) + TVM_FFI_ICHECK_EQ((*data_ptr)->shape[3], qk_head_dim_) << error_msg << " The number of head features mismatches."; } @@ -1640,29 +1651,32 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { append_position_map.data() + start_pos, (end_pos - start_pos) * ((dtype_aux_.bits * dtype_aux_.lanes + 7) / 8)); for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) { - CHECK(attn_kinds_[layer_id] == AttnKind::kMHA) << "Only MHA is supported for DebugGetKV"; + TVM_FFI_ICHECK(attn_kinds_[layer_id] == AttnKind::kMHA) + << "Only MHA is supported for DebugGetKV"; f_debug_get_kv_.value()(pages_[layer_id], position_map_device, k_data, v_data, layer_id); } } void DebugGetKVMLA(int64_t seq_id, int64_t start_pos, int64_t end_pos, Tensor kv_data) final { - CHECK(f_debug_get_kv_.defined()) + TVM_FFI_ICHECK(f_debug_get_kv_.defined()) << "PageAttentionKVCache requires the `f_debug_get_kv` to be explicitly passed in when " "initialization. Please construct the KV cache with `f_debug_get_kv`."; const Sequence& seq = seq_map_.at(seq_id); - CHECK_GE(start_pos, 0) << "DebugGetKV does not accept negative start_pos " << start_pos; - CHECK_LE(end_pos, seq.seq_length) << "DebugGetKV does not accept out-of-range end_pos"; - CHECK_LT(start_pos, end_pos) << "DebugGetKV does not accept \"start_pos >= end_pos\""; + TVM_FFI_ICHECK_GE(start_pos, 0) + << "DebugGetKV does not accept negative start_pos " << start_pos; + TVM_FFI_ICHECK_LE(end_pos, seq.seq_length) << "DebugGetKV does not accept out-of-range end_pos"; + TVM_FFI_ICHECK_LT(start_pos, end_pos) << "DebugGetKV does not accept \"start_pos >= end_pos\""; // kv_data: (num_layers, seq_length, qk_head_dim) static constexpr const char* error_msg = "DebugGetKV expects the kv_data in layout (num_layers, seq_length, qk_head_dim)."; - CHECK_EQ(kv_data->ndim, 3) << error_msg; - CHECK_EQ(kv_data->shape[0], num_layers_) << error_msg << " The number of layers mismatches."; - CHECK_EQ(kv_data->shape[1], end_pos - start_pos) + TVM_FFI_ICHECK_EQ(kv_data->ndim, 3) << error_msg; + TVM_FFI_ICHECK_EQ(kv_data->shape[0], num_layers_) + << error_msg << " The number of layers mismatches."; + TVM_FFI_ICHECK_EQ(kv_data->shape[1], end_pos - start_pos) << error_msg << " The sequence length mismatches."; - CHECK_EQ(kv_data->shape[2], qk_head_dim_) + TVM_FFI_ICHECK_EQ(kv_data->shape[2], qk_head_dim_) << error_msg << " The number of head features mismatches."; std::vector trace = seq.GetBlockTrace(global_block_pool_); @@ -1683,13 +1697,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { append_position_map.data() + start_pos, (end_pos - start_pos) * ((dtype_aux_.bits * dtype_aux_.lanes + 7) / 8)); for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) { - CHECK(attn_kinds_[layer_id] == AttnKind::kMLA) << "Only MHA is supported for DebugGetKVMLA"; + TVM_FFI_ICHECK(attn_kinds_[layer_id] == AttnKind::kMLA) + << "Only MHA is supported for DebugGetKVMLA"; f_debug_get_kv_.value()(pages_[layer_id], position_map_device, kv_data, layer_id); } } void DebugSetKV(int64_t seq_id, int64_t start_pos, Tensor k_data, Tensor v_data) final { - ICHECK(false) << "DebugSetKV for PageAttentionKVCache not implemented yet."; + TVM_FFI_ICHECK(false) << "DebugSetKV for PageAttentionKVCache not implemented yet."; } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.vm.PagedAttentionKVCache", PagedAttentionKVCacheObj, AttentionKVCacheObj); @@ -1698,7 +1713,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /*! \brief Get a new free page and return its id. */ int32_t GetFreePage() { // Find a page from the free page pools. - CHECK(!free_page_ids_.empty()) << "The KV cache is full. No page can be allocated."; + TVM_FFI_ICHECK(!free_page_ids_.empty()) << "The KV cache is full. No page can be allocated."; int32_t page_id = free_page_ids_.back(); free_page_ids_.pop_back(); return page_id; @@ -1710,7 +1725,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { int32_t block_idx = free_block_idx_.back(); free_block_idx_.pop_back(); global_block_pool_[block_idx].Reset(); - ICHECK_EQ(global_block_pool_[block_idx].index, block_idx); + TVM_FFI_ICHECK_EQ(global_block_pool_[block_idx].index, block_idx); return block_idx; } @@ -1738,8 +1753,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int d = 0; d < num_depths_; ++d) { // We check if the token tree deteriorates to a chain, // because chain cases can have simplified attention work flow. - ICHECK_LT(d, tree_attn_mask_host_.size()); - ICHECK_LT(d, tree_attn_mn_indptr_host_.size()); + TVM_FFI_ICHECK_LT(d, tree_attn_mask_host_.size()); + TVM_FFI_ICHECK_LT(d, tree_attn_mn_indptr_host_.size()); HostMemoryVector& tree_attn_mn_indptr = tree_attn_mn_indptr_host_[d]; HostMemoryVector& tree_attn_mask = tree_attn_mask_host_[d]; @@ -1752,8 +1767,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { bool is_chain = true; // - Construct the mn indptr array, which is the indptr of the mask size of each sequence. tree_attn_mn_indptr.push_back(0); - ICHECK_EQ(sequences.size(), cur_batch_size_); - ICHECK_EQ(cur_append_lengths_.size(), cur_batch_size_); + TVM_FFI_ICHECK_EQ(sequences.size(), cur_batch_size_); + TVM_FFI_ICHECK_EQ(cur_append_lengths_.size(), cur_batch_size_); int64_t token_tree_parent_ptr_offset = 0; for (int i = 0; i < cur_batch_size_; ++i) { int64_t append_length = cur_append_lengths_[i]; @@ -1764,21 +1779,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { continue; } // Update the token tree parent pointers. - CHECK_LE(sequences[i]->token_tree_parent_ptr.size(), - global_block_pool_[sequences[i]->last_block_idx].seq_length) + TVM_FFI_ICHECK_LE(sequences[i]->token_tree_parent_ptr.size(), + global_block_pool_[sequences[i]->last_block_idx].seq_length) << "The token tree size is larger than the sequence length of the last block."; std::copy(token_tree_parent_ptr.begin() + token_tree_parent_ptr_offset, token_tree_parent_ptr.begin() + token_tree_parent_ptr_offset + append_length, std::back_inserter(sequences[i]->token_tree_parent_ptr)); token_tree_parent_ptr_offset += append_length; - CHECK_LE(sequences[i]->token_tree_parent_ptr.size(), kTreeAttnMaxTreeSize) + TVM_FFI_ICHECK_LE(sequences[i]->token_tree_parent_ptr.size(), kTreeAttnMaxTreeSize) << "The tree size is " << append_length << " which exceeds the maximum tree size limit " << kTreeAttnMaxTreeSize; tree_attn_mn_indptr.push_back(tree_attn_mn_indptr.back() + sequences[i]->token_tree_parent_ptr.size()); } - CHECK_EQ(token_tree_parent_ptr.size(), token_tree_parent_ptr_offset) + TVM_FFI_ICHECK_EQ(token_tree_parent_ptr.size(), token_tree_parent_ptr_offset) << "Invalid token tree size. The sum of \"append_lengths\" is " << token_tree_parent_ptr_offset << " while there are " << token_tree_parent_ptr.size() << " elements in \"token_tree_parent_ptr\"."; @@ -1798,10 +1813,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::unordered_map> tree_parent_to_children; std::vector tree_roots; for (int n = 0; n < tree_size; ++n) { - CHECK_LT(sequences[i]->token_tree_parent_ptr[n], n) + TVM_FFI_ICHECK_LT(sequences[i]->token_tree_parent_ptr[n], n) << "Invalid token tree. The parent of node " << n << " in tree " << i << " is " << sequences[i]->token_tree_parent_ptr[n] << ", which is not smaller than " << n; - CHECK_GE(sequences[i]->token_tree_parent_ptr[n], -1) + TVM_FFI_ICHECK_GE(sequences[i]->token_tree_parent_ptr[n], -1) << "Invalid token tree. The parent of node " << n << " in tree " << i << " is " << sequences[i]->token_tree_parent_ptr[n]; if (sequences[i]->token_tree_parent_ptr[n] != n - 1) { @@ -1873,7 +1888,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // and thus we set the sink length of the last block, the index of the // first sliding page, and starting offset in first sliding page. if (seq->last_block_attn_sink_size > 0 && block.sink_length == 0) { - ICHECK_EQ(block.sliding_window_offset, 0); + TVM_FFI_ICHECK_EQ(block.sliding_window_offset, 0); block.sink_length = seq->last_block_attn_sink_size; block.sliding_window_offset = seq->last_block_attn_sink_size; } @@ -1897,17 +1912,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } // - The first sliding page after sliding is either the last sink page, // or the page next to the last sink page. - ICHECK(page_idx_after_sliding == num_sink_pages - 1 || - page_idx_after_sliding == num_sink_pages); + TVM_FFI_ICHECK(page_idx_after_sliding == num_sink_pages - 1 || + page_idx_after_sliding == num_sink_pages); // - Update the length of the sequence and the block. seq->seq_length = seq->sliding_window_size; block.seq_length -= length_to_slide; block.sliding_window_offset = page_idx_after_sliding * page_size_ + page_start_offset_after_sliding; - ICHECK_GE(block.seq_length, block.sink_length); - ICHECK_GE(block.sliding_window_offset, block.sink_length); - ICHECK_EQ( + TVM_FFI_ICHECK_GE(block.seq_length, block.sink_length); + TVM_FFI_ICHECK_GE(block.sliding_window_offset, block.sink_length); + TVM_FFI_ICHECK_EQ( (block.sliding_window_offset + (block.seq_length - block.sink_length) + page_size_ - 1) / page_size_, block.page_ids.size()); @@ -1926,8 +1941,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void ReserveAppendLengthInSeq(Sequence* seq, int64_t append_length) { int32_t block_idx = seq->last_block_idx; Block& block = global_block_pool_[block_idx]; - CHECK_GT(append_length, 0) << "Append with length 0 is not allowed."; - CHECK_EQ(block.external_ref_cnt, 1) + TVM_FFI_ICHECK_GT(append_length, 0) << "Append with length 0 is not allowed."; + TVM_FFI_ICHECK_EQ(block.external_ref_cnt, 1) << "The block is " << block.external_ref_cnt - 1 << "-time referenced by other blocks, thus cannot accept new KV values."; @@ -2016,7 +2031,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (page_indices_on_depths_view_[d]->shape[0] == 0) { continue; } - CHECK(!support_sliding_window_ || !support_layer_sliding_window_) + TVM_FFI_ICHECK(!support_sliding_window_ || !support_layer_sliding_window_) << "Kernel BeginForward doesn't support sliding window."; if (use_decode_kernel_[d]) { if (f_attention_decode_ != nullptr && @@ -2061,7 +2076,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (page_indices_on_depths_view_[d]->shape[0] == 0) { continue; } - CHECK(!support_sliding_window_) << "Kernel BeginForward doesn't support sliding window."; + TVM_FFI_ICHECK(!support_sliding_window_) + << "Kernel BeginForward doesn't support sliding window."; if (f_mla_prefill_ != nullptr && f_mla_prefill_->backend_kind == AttnBackendKind::kFlashInfer) { f_mla_prefill_->BeginForward( @@ -2082,8 +2098,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void AttentionInternal(int64_t layer_id, Tensor q_data, Tensor k_data, Tensor v_data, Tensor output, double sm_scale) { int64_t local_layer_id = layer_id - layer_id_begin_offset_; - CHECK_GE(local_layer_id, 0); - CHECK_LT(local_layer_id, num_layers_); + TVM_FFI_ICHECK_GE(local_layer_id, 0); + TVM_FFI_ICHECK_LT(local_layer_id, num_layers_); bool is_first_kernel = true; if (!append_before_attn_) { @@ -2094,7 +2110,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { bool self_attn_computed = !is_first_kernel; bool cross_attn_computed = MHACrossAttnInternal( local_layer_id, q_data, output, merged_attn_lse_view_, sm_scale, is_first_kernel); - CHECK(self_attn_computed || cross_attn_computed) + TVM_FFI_ICHECK(self_attn_computed || cross_attn_computed) << "Both self-attention and cross-attention are not computed."; } @@ -2102,17 +2118,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { Tensor lse_data, double sm_scale) { if (is_chain_on_depths_[0]) { // If the batch does not form a tree, use raggedness prefill kernel. - ICHECK_NOTNULL(f_attention_prefill_ragged_); + TVM_FFI_ICHECK_NOTNULL(f_attention_prefill_ragged_); f_attention_prefill_ragged_->MHA( q_data, k_data, v_data, cur_append_length_indptr_view_, cur_append_length_indptr_view_, q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, /*causal=*/true, rope_mode_, rotary_scale_, rotary_theta_, sm_scale, o_data, lse_data, compute_stream_); } else { // The batch requires tree attention. - ICHECK(f_attention_prefill_with_tree_mask_ != nullptr) + TVM_FFI_ICHECK(f_attention_prefill_with_tree_mask_ != nullptr) << "Function \"f_attention_prefill_with_tree_mask_\" is not defined."; - ICHECK(tree_attn_mask_view_[0].defined()); - ICHECK(tree_attn_mn_indptr_view_[0].defined()); + TVM_FFI_ICHECK(tree_attn_mask_view_[0].defined()); + TVM_FFI_ICHECK(tree_attn_mn_indptr_view_[0].defined()); f_attention_prefill_with_tree_mask_->MHA( q_data, k_data, v_data, cur_append_length_indptr_view_, cur_append_length_indptr_view_, q_rope_position_map_view_, tree_attn_mn_indptr_view_[0], tree_attn_mask_view_[0], @@ -2122,9 +2138,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void MLASelfAttnInternal(Tensor q_data, Tensor k_data, Tensor v_data, Tensor o_data, Tensor lse_data, double sm_scale) { - CHECK(is_chain_on_depths_[0]) << "Tree attn not able for MLA for now."; + TVM_FFI_ICHECK(is_chain_on_depths_[0]) << "Tree attn not able for MLA for now."; // If the batch does not form a tree, use raggedness prefill kernel. - ICHECK_NOTNULL(f_attention_prefill_ragged_); + TVM_FFI_ICHECK_NOTNULL(f_attention_prefill_ragged_); f_attention_prefill_ragged_->MHA( q_data, k_data, v_data, cur_append_length_indptr_view_, cur_append_length_indptr_view_, q_rope_position_map_view_, k_ragged_rope_pos_offset_view_, /*causal=*/true, RoPEMode::kNone, @@ -2144,7 +2160,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { attn_kinds_[local_layer_id + layer_id_begin_offset_] != AttnKind::kMHASliding) ? f_attention_decode_ : f_attention_decode_sliding_window_; - CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1."; + TVM_FFI_ICHECK_GE(num_depths_, 1) + << "The number of effective depths must be greater or equal to 1."; bool cross_attn_computed = false; for (int d = 0; d < num_depths_; ++d) { @@ -2185,7 +2202,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } if (append_before_attn_ && !is_chain_on_depths_[d]) { - ICHECK_NOTNULL(f_attention_prefill_with_tree_mask_paged_kv_); + TVM_FFI_ICHECK_NOTNULL(f_attention_prefill_with_tree_mask_paged_kv_); f_attention_prefill_with_tree_mask_paged_kv_->MHA( q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr, page_indices, length_info, k_rope_pos, q_rope_position_map_view_, tree_attn_mn_indptr_view_[d], @@ -2193,13 +2210,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { attn_lse, compute_stream_); } else if (use_decode_kernel_[d]) { // Use decode kernel for depth d - ICHECK_NOTNULL(f_decode); + TVM_FFI_ICHECK_NOTNULL(f_decode); f_decode->MHA(d, q_data, pages_[local_layer_id], page_indptr, page_indices, length_info, k_rope_pos, q_rope_position_map_view_, rope_mode_, rotary_scale, rotary_theta, sm_scale, attn_output, attn_lse, compute_stream_); } else { // Use prefill kernel for depth d - ICHECK_NOTNULL(f_prefill); + TVM_FFI_ICHECK_NOTNULL(f_prefill); f_prefill->MHA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr, page_indices, length_info, q_rope_position_map_view_, k_rope_pos, /*causal=*/false, @@ -2220,7 +2237,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { /*! \brief Compute cross-attention for MLA. Return if there is effective computation. */ bool MLACrossAttnInternal(int64_t local_layer_id, Tensor q_data, Tensor o_data, Tensor lse_data, double sm_scale) { - CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1."; + TVM_FFI_ICHECK_GE(num_depths_, 1) + << "The number of effective depths must be greater or equal to 1."; bool is_first_kernel = true; for (int d = 0; d < num_depths_; ++d) { @@ -2236,8 +2254,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { attn_output = temp_attn_output_view_; attn_lse = temp_attn_lse_view_; } - CHECK(is_chain_on_depths_[d]) << "Tree attn not able for MLA for now."; - ICHECK_NOTNULL(f_mla_prefill_); + TVM_FFI_ICHECK(is_chain_on_depths_[d]) << "Tree attn not able for MLA for now."; + TVM_FFI_ICHECK_NOTNULL(f_mla_prefill_); f_mla_prefill_->MLA(d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id], page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d], length_info_on_depths_view_[d], /*causal=*/false, sm_scale, attn_output, @@ -2277,7 +2295,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { * invoked before running attention computation on device. */ void SyncAuxArrayToDevice() { - ICHECK(dtype_aux_.bits == 32 && dtype_aux_.code == kDLInt); + TVM_FFI_ICHECK(dtype_aux_.bits == 32 && dtype_aux_.code == kDLInt); int64_t total_append_length = 0; int num_sequences = cur_append_lengths_.size(); cur_append_lengths_indptr_host_.clear(); @@ -2287,16 +2305,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { cur_append_lengths_[i]); } total_append_length = cur_append_lengths_indptr_host_.back(); - ICHECK_EQ(total_append_length, append_position_map_host_.size()); - ICHECK_EQ(total_append_length, kv_transfer_remote_position_map_host_.size()); - ICHECK_EQ(total_append_length, kv_transfer_recver_id_host_.size()); + TVM_FFI_ICHECK_EQ(total_append_length, append_position_map_host_.size()); + TVM_FFI_ICHECK_EQ(total_append_length, kv_transfer_remote_position_map_host_.size()); + TVM_FFI_ICHECK_EQ(total_append_length, kv_transfer_recver_id_host_.size()); // - Reset the copy. aux_data_manager_->ResetAttnAuxDataCopy(); // 1. q_rope_position_map // q_rope_position_map has to be synced first so that it has a 0 byte offset - ICHECK_EQ(q_rope_position_map_host_.size(), total_append_length); + TVM_FFI_ICHECK_EQ(q_rope_position_map_host_.size(), total_append_length); q_rope_position_map_view_ = aux_data_manager_->CopyQRoPEPosMapAsync(&q_rope_position_map_host_); // 2. qo_indptr_on_depths for (int d = 0; d < num_depths_; ++d) { @@ -2305,13 +2323,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } // 3. page_indptr_on_depths for (int d = 0; d < num_depths_; ++d) { - ICHECK_EQ(page_indptr_on_depths_host_[d].size(), qo_indptr_on_depths_host_[d].size()); + TVM_FFI_ICHECK_EQ(page_indptr_on_depths_host_[d].size(), qo_indptr_on_depths_host_[d].size()); page_indptr_on_depths_view_[d] = aux_data_manager_->CopyPageIndptrOnDepthAsync(&page_indptr_on_depths_host_[d], d); } // 4. page_indices_on_depths for (int d = 0; d < num_depths_; ++d) { - ICHECK_EQ(page_indices_on_depths_host_[d].size(), page_indptr_on_depths_host_[d].back()); + TVM_FFI_ICHECK_EQ(page_indices_on_depths_host_[d].size(), + page_indptr_on_depths_host_[d].back()); page_indices_on_depths_view_[d] = aux_data_manager_->CopyPageIndicesOnDepthAsync(&page_indices_on_depths_host_[d], d); } @@ -2320,16 +2339,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (support_layer_sliding_window_) { // 5. page_indptr_sliding_window_on_depths for (int d = 0; d < num_depths_; ++d) { - ICHECK_EQ(page_indptr_sliding_window_on_depths_host_[d].size(), - qo_indptr_on_depths_host_[d].size()); + TVM_FFI_ICHECK_EQ(page_indptr_sliding_window_on_depths_host_[d].size(), + qo_indptr_on_depths_host_[d].size()); page_indptr_sliding_window_on_depths_view_[d] = aux_data_manager_->CopyPageIndptrOnDepthAsync( &page_indptr_sliding_window_on_depths_host_[d], d); } // 6. page_indices_sliding_window_on_depths for (int d = 0; d < num_depths_; ++d) { - ICHECK_EQ(page_indices_sliding_window_on_depths_host_[d].size(), - page_indptr_sliding_window_on_depths_host_[d].back()); + TVM_FFI_ICHECK_EQ(page_indices_sliding_window_on_depths_host_[d].size(), + page_indptr_sliding_window_on_depths_host_[d].back()); page_indices_sliding_window_on_depths_view_[d] = aux_data_manager_->CopyPageIndicesOnDepthAsync( &page_indices_sliding_window_on_depths_host_[d], d); @@ -2341,9 +2360,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // sink_size_on_depths_host_; for (int d = 0; d < num_depths_; ++d) { int num_seq_on_layer = static_cast(qo_indptr_on_depths_host_[d].size()) - 1; - ICHECK_EQ(last_page_len_on_depths_host_[d].size(), num_seq_on_layer); - ICHECK_EQ(sliding_window_offset_on_depths_host_[d].size(), num_seq_on_layer); - ICHECK_EQ(sink_size_on_depths_host_[d].size(), num_seq_on_layer); + TVM_FFI_ICHECK_EQ(last_page_len_on_depths_host_[d].size(), num_seq_on_layer); + TVM_FFI_ICHECK_EQ(sliding_window_offset_on_depths_host_[d].size(), num_seq_on_layer); + TVM_FFI_ICHECK_EQ(sink_size_on_depths_host_[d].size(), num_seq_on_layer); if (!support_sliding_window_) { // Sliding window is not enabled, so we first copy "last_page_len". length_info_on_depths_view_[d] = @@ -2364,13 +2383,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } // 6. k_rope_pos_offset_on_depths for (int d = 0; d < num_depths_; ++d) { - ICHECK_EQ(k_rope_pos_offset_on_depths_host_[d].size() + 1, - qo_indptr_on_depths_host_[d].size()); + TVM_FFI_ICHECK_EQ(k_rope_pos_offset_on_depths_host_[d].size() + 1, + qo_indptr_on_depths_host_[d].size()); k_rope_pos_offset_view_[d] = aux_data_manager_->CopyKRoPEPosOffsetOnDepthAsync( &k_rope_pos_offset_on_depths_host_[d], d); if (support_layer_sliding_window_) { - ICHECK_EQ(k_rope_pos_offset_sliding_window_on_depths_host_[d].size() + 1, - qo_indptr_on_depths_host_[d].size()); + TVM_FFI_ICHECK_EQ(k_rope_pos_offset_sliding_window_on_depths_host_[d].size() + 1, + qo_indptr_on_depths_host_[d].size()); k_rope_pos_offset_sliding_window_view_[d] = aux_data_manager_->CopyKRoPEPosOffsetOnDepthAsync( &k_rope_pos_offset_sliding_window_on_depths_host_[d], d); @@ -2380,7 +2399,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { cur_append_length_indptr_view_ = aux_data_manager_->CopyCurAppendLengthIndptrAsync(&cur_append_lengths_indptr_host_); // 8. k_ragged_rope_pos_offset - ICHECK_EQ(k_ragged_rope_pos_offset_host_.size(), num_sequences); + TVM_FFI_ICHECK_EQ(k_ragged_rope_pos_offset_host_.size(), num_sequences); k_ragged_rope_pos_offset_view_ = aux_data_manager_->CopyKRaggedRoPEPosOffsetAsync(&k_ragged_rope_pos_offset_host_); // 9. append_position_map @@ -2438,7 +2457,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def_packed( "vm.builtin.paged_attention_kv_cache_create", [](ffi::PackedArgs args, ffi::Any* rv) { // Todo: cuda graph arg - CHECK(args.size() == 28 || args.size() == 29) + TVM_FFI_ICHECK(args.size() == 28 || args.size() == 29) << "Invalid number of KV cache constructor args: " << args.size(); ffi::Shape cache_config = args[0].cast(); ffi::Shape layer_indptr_tuple = args[1].cast(); @@ -2449,7 +2468,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { num_groups = disco_worker->num_groups; group_id = disco_worker->worker_id / (disco_worker->num_workers / num_groups); } - CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1); + TVM_FFI_ICHECK_EQ(layer_indptr_tuple.size(), num_groups + 1); int64_t num_layers = layer_indptr_tuple[group_id + 1] - layer_indptr_tuple[group_id]; int64_t layer_id_begin_offset = layer_indptr_tuple[group_id]; int64_t layer_id_end_offset = layer_indptr_tuple[group_id + 1]; @@ -2499,7 +2518,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { }; f_transpose_append_mha = f_convert_optional_packed_func(13); f_transpose_append_mla = f_convert_optional_packed_func(14); - CHECK(!f_merge_inplace.empty()) << "Merge inplace function is not defined."; + TVM_FFI_ICHECK(!f_merge_inplace.empty()) << "Merge inplace function is not defined."; std::vector attn_kinds_vec; attn_kinds_vec.reserve(attn_kinds.size()); @@ -2507,7 +2526,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { attn_kinds_vec.push_back(static_cast(attn_kind)); } - CHECK_EQ(cache_config.size(), 5); + TVM_FFI_ICHECK_EQ(cache_config.size(), 5); int64_t reserved_num_seqs = cache_config[0]; int64_t total_token_capacity = cache_config[1]; int64_t prefill_chunk_size = cache_config[2]; diff --git a/src/runtime/vm/rnn_state.cc b/src/runtime/vm/rnn_state.cc index 61194b5dade2..a06ea41c5404 100644 --- a/src/runtime/vm/rnn_state.cc +++ b/src/runtime/vm/rnn_state.cc @@ -185,7 +185,7 @@ class RNNStateImpObj : public RNNStateObj { storages_.push_back(layer_storages); } - CHECK_GT(max_history_, 0) << "At least 1 history slot to store the current state"; + TVM_FFI_ICHECK_GT(max_history_, 0) << "At least 1 history slot to store the current state"; // Allocate the auxiliary arrays on device. seq_slot_ids_device_ = Tensor::Empty({reserved_num_seqs}, dtype_aux_, device); @@ -197,7 +197,7 @@ class RNNStateImpObj : public RNNStateObj { /*! \brief Reset the KV cache. */ void Clear() final { seq_map_.clear(); - ICHECK(!storages_.empty()); + TVM_FFI_ICHECK(!storages_.empty()); free_slot_ids_.clear(); for (int64_t slot_id = reserved_num_seqs_ - 1; slot_id >= 0; --slot_id) { free_slot_ids_.push_back(slot_id); @@ -209,7 +209,7 @@ class RNNStateImpObj : public RNNStateObj { void BeginForward(const ffi::Shape& seq_ids, const ffi::Shape& append_lengths, const ffi::Optional& opt_token_tree_parent_ptr) final { - CHECK_EQ(seq_ids.size(), append_lengths.size()) + TVM_FFI_ICHECK_EQ(seq_ids.size(), append_lengths.size()) << "The seq_ids size (" << seq_ids.size() << ") and append_lengths size (" << append_lengths.size() << ") mismatch."; @@ -218,7 +218,7 @@ class RNNStateImpObj : public RNNStateObj { int matched_pos = 0; for (int64_t append_length : append_lengths) { for (int64_t i = 0; i < append_length; ++i) { - CHECK_EQ(token_tree_parent_ptr[matched_pos], i - 1) + TVM_FFI_ICHECK_EQ(token_tree_parent_ptr[matched_pos], i - 1) << "Unexpected token tree for RNN state. RNN state only supports chains as token " "trees."; ++matched_pos; @@ -239,8 +239,8 @@ class RNNStateImpObj : public RNNStateObj { int64_t seq_id = cur_seq_ids_[i]; int64_t seq_length = cur_append_lengths_[i]; auto it = seq_map_.find(seq_id); - CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id - << "\" cannot be found in the space state storage."; + TVM_FFI_ICHECK(it != seq_map_.end()) + << "The sequence \"" << seq_id << "\" cannot be found in the space state storage."; it->second.seq_length += seq_length; if (seq_length > 1) { // We cannot rollback the prefill input @@ -261,12 +261,12 @@ class RNNStateImpObj : public RNNStateObj { void Get(int64_t layer_id, int64_t state_id, Tensor o_data) final { // The auxiliary data structure on device must have been synchronized. - CHECK(!dirty_aux_data_device_) + TVM_FFI_ICHECK(!dirty_aux_data_device_) << "The auxiliary arrays are not synchronized to device. Please call " "`BeginForward` to synchronize before calling `Get`."; - ICHECK(cur_batch_size_ == static_cast(cur_seq_ids_.size())) + TVM_FFI_ICHECK(cur_batch_size_ == static_cast(cur_seq_ids_.size())) << "The batch size is not consistent with the number of sequence ids."; - CHECK_GT(cur_batch_size_, 0) << "The curent batch size should be greater than 0."; + TVM_FFI_ICHECK_GT(cur_batch_size_, 0) << "The curent batch size should be greater than 0."; // TODO(siyuan): support zero-copy when seq_len is one // Copy the state data to the return array. Tensor state = storages_[layer_id][state_id]; @@ -275,12 +275,12 @@ class RNNStateImpObj : public RNNStateObj { void Set(int64_t layer_id, int64_t state_id, Tensor data) final { // The auxiliary data structure on device must have been synchronized. - CHECK(!dirty_aux_data_device_) + TVM_FFI_ICHECK(!dirty_aux_data_device_) << "The auxiliary arrays are not synchronized to device. Please call " "`BeginForward` to synchronize before calling `Set`."; - ICHECK(cur_batch_size_ == static_cast(cur_seq_ids_.size())) + TVM_FFI_ICHECK(cur_batch_size_ == static_cast(cur_seq_ids_.size())) << "The batch size is not consistent with the number of sequence ids."; - CHECK_GT(cur_batch_size_, 0) << "The curent batch size should be greater than 0."; + TVM_FFI_ICHECK_GT(cur_batch_size_, 0) << "The curent batch size should be greater than 0."; Tensor state = storages_[layer_id][state_id]; f_sets_[state_id](state, seq_slot_ids_view_, history_slot_ids_view_, data); @@ -288,8 +288,8 @@ class RNNStateImpObj : public RNNStateObj { Tensor DebugGet(int64_t layer_id, int64_t state_id, int64_t seq_id) { auto it = seq_map_.find(seq_id); - CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id - << "\" cannot be found in the space state storage."; + TVM_FFI_ICHECK(it != seq_map_.end()) + << "The sequence \"" << seq_id << "\" cannot be found in the space state storage."; Tensor state = storages_[layer_id][state_id]; int64_t seq_slot_id = it->second.seq_slot_id; int64_t history_slot_id = it->second.history_slot_id; @@ -306,7 +306,7 @@ class RNNStateImpObj : public RNNStateObj { /************** Sequence Management **************/ void AddSequence(int64_t seq_id) final { - CHECK(seq_map_.find(seq_id) == seq_map_.end()) + TVM_FFI_ICHECK(seq_map_.find(seq_id) == seq_map_.end()) << "The sequence \"" << seq_id << "\" is already in the space state storage."; int64_t seq_slot_id = GetFreeSlot(); seq_map_.insert({seq_id, Sequence(seq_slot_id)}); @@ -326,8 +326,8 @@ class RNNStateImpObj : public RNNStateObj { void RemoveSequence(int64_t seq_id) final { auto it = seq_map_.find(seq_id); - CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id - << "\" cannot be found in the space state storage."; + TVM_FFI_ICHECK(it != seq_map_.end()) + << "The sequence \"" << seq_id << "\" cannot be found in the space state storage."; free_slot_ids_.push_back(it->second.seq_slot_id); seq_map_.erase(it); @@ -337,9 +337,9 @@ class RNNStateImpObj : public RNNStateObj { void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) final { auto parent_it = seq_map_.find(parent_seq_id); - CHECK(parent_it != seq_map_.end()) << "The parent sequence \"" << parent_seq_id - << "\" cannot be found in space state storage."; - CHECK(seq_map_.find(child_seq_id) == seq_map_.end()) + TVM_FFI_ICHECK(parent_it != seq_map_.end()) << "The parent sequence \"" << parent_seq_id + << "\" cannot be found in space state storage."; + TVM_FFI_ICHECK(seq_map_.find(child_seq_id) == seq_map_.end()) << "The child sequence \"" << child_seq_id << "\" is already in the space state storage."; // Create a child block with the parent block pointer. @@ -360,10 +360,10 @@ class RNNStateImpObj : public RNNStateObj { void PopN(int64_t seq_id, int32_t n) final { auto it = seq_map_.find(seq_id); - CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id - << "\" cannot be found in space state."; - CHECK_GE(n, 0) << "The length of rolling back " << n << " cannot be negative."; - CHECK_LE(n, it->second.available_history_num) + TVM_FFI_ICHECK(it != seq_map_.end()) + << "The sequence \"" << seq_id << "\" cannot be found in space state."; + TVM_FFI_ICHECK_GE(n, 0) << "The length of rolling back " << n << " cannot be negative."; + TVM_FFI_ICHECK_LE(n, it->second.available_history_num) << "The sequence only has " << it->second.available_history_num << " available history in the space state storage, while the length of rollback is " << n << " which exceeds the sequence length."; @@ -377,7 +377,8 @@ class RNNStateImpObj : public RNNStateObj { private: /*! \brief Get a new free block and return its index. */ int32_t GetFreeSlot() { - CHECK(!free_slot_ids_.empty()) << "The Sequence slot is full, cannot accept new sequence."; + TVM_FFI_ICHECK(!free_slot_ids_.empty()) + << "The Sequence slot is full, cannot accept new sequence."; int32_t seq_slot_id = free_slot_ids_.back(); free_slot_ids_.pop_back(); return seq_slot_id; @@ -441,8 +442,8 @@ class RNNStateImpObj : public RNNStateObj { history_slot_ids.reserve(cur_batch_size_); for (int64_t seq_id : cur_seq_ids_) { auto it = seq_map_.find(seq_id); - CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id - << "\" cannot be found in the space state storage."; + TVM_FFI_ICHECK(it != seq_map_.end()) + << "The sequence \"" << seq_id << "\" cannot be found in the space state storage."; const Sequence& seq = it->second; seq_slot_ids.push_back(seq.seq_slot_id); history_slot_ids.push_back(seq.history_slot_id); @@ -473,21 +474,23 @@ TVM_FFI_STATIC_INIT_BLOCK() { ffi::Array f_gets, // ffi::Array f_sets, // ffi::Array init_layer_value) { - CHECK_GT(num_layers, 0) << "The number of layers should be greater than 0."; - CHECK_GT(reserved_num_seqs, 0) << "The number of reserved sequences should be greater than 0."; - CHECK_GE(max_history, 0) << "The maximum history length should be greater or equal than 0."; - CHECK_GT(init_layer_value.size(), 0) + TVM_FFI_ICHECK_GT(num_layers, 0) << "The number of layers should be greater than 0."; + TVM_FFI_ICHECK_GT(reserved_num_seqs, 0) + << "The number of reserved sequences should be greater than 0."; + TVM_FFI_ICHECK_GE(max_history, 0) + << "The maximum history length should be greater or equal than 0."; + TVM_FFI_ICHECK_GT(init_layer_value.size(), 0) << "The number of states per layer should be greater than 0."; Device device = init_layer_value[0]->device; for (const Tensor& state : init_layer_value) { - CHECK(state->device.device_type == device.device_type && - state->device.device_id == device.device_id) + TVM_FFI_ICHECK(state->device.device_type == device.device_type && + state->device.device_id == device.device_id) << "The device type of all states should be the same."; } - CHECK_EQ(f_gets.size(), init_layer_value.size()) + TVM_FFI_ICHECK_EQ(f_gets.size(), init_layer_value.size()) << "The number of state getters should be the same as the number of states per layer, " << "but got " << f_gets.size() << " and " << init_layer_value.size() << " respectively."; - CHECK_EQ(f_sets.size(), init_layer_value.size()) + TVM_FFI_ICHECK_EQ(f_sets.size(), init_layer_value.size()) << "The number of state setters should be the same as the number of states per layer, " << "but got " << f_sets.size() << " and " << init_layer_value.size() << " respectively."; ObjectPtr n = diff --git a/src/runtime/vm/tensor_cache_support.cc b/src/runtime/vm/tensor_cache_support.cc index 3bd249ccbafa..c115efbd41e0 100644 --- a/src/runtime/vm/tensor_cache_support.cc +++ b/src/runtime/vm/tensor_cache_support.cc @@ -100,10 +100,11 @@ TensorCacheMetadata TensorCacheMetadata::LoadFromStr(const std::string& json_str ffi::String err; json::Value json_info = json::Parse(json_str, &err); if (!err.empty()) { - LOG(FATAL) << "Failed to parse JSON: " << err << ". The JSON string is:" << json_str; + TVM_FFI_THROW(InternalError) << "Failed to parse JSON: " << err + << ". The JSON string is:" << json_str; } - CHECK(json_info.as()) - << "ValueError: The given string is not a JSON object: " << json_str; + TVM_FFI_CHECK(json_info.as(), ValueError) + << "The given string is not a JSON object: " << json_str; TensorCacheMetadata result = JSONAsTensorCacheMetadata(json_info.cast()); result.path = path; return result; @@ -115,10 +116,11 @@ TVM_DLL TensorCacheMetadata TensorCacheMetadata::Load(const std::string& path) { ffi::String err; json::Value json_info = json::Parse(json_str, &err); if (!err.empty()) { - LOG(FATAL) << "Failed to parse JSON: " << err << ". The JSON string is:" << json_str; + TVM_FFI_THROW(InternalError) << "Failed to parse JSON: " << err + << ". The JSON string is:" << json_str; } - CHECK(json_info.as()) - << "ValueError: The given string is not a JSON object: " << json_str; + TVM_FFI_CHECK(json_info.as(), ValueError) + << "The given string is not a JSON object: " << json_str; TensorCacheMetadata result = JSONAsTensorCacheMetadata(json_info.cast()); result.path = path; return result; @@ -173,9 +175,9 @@ TVM_DLL ffi::Array TensorCacheMetadata::FileRecord::Load( std::string* raw_data_buffer, // ffi::Optional* staging_buffer) const { LoadBinaryFromFile(path_prefix + "/" + this->data_path, raw_data_buffer); - CHECK_EQ(this->format, "raw-shard") << "ValueError: Only `raw-shard` format is supported"; - CHECK_EQ(this->nbytes, raw_data_buffer->length()) - << "ValueError: Encountered an corrupted parameter shard. It means it is not downloaded " + TVM_FFI_CHECK_EQ(this->format, "raw-shard", ValueError) << "Only `raw-shard` format is supported"; + TVM_FFI_CHECK_EQ(this->nbytes, raw_data_buffer->length(), ValueError) + << "Encountered an corrupted parameter shard. It means it is not downloaded " "completely or downloading is interrupted. Please try to download again."; ffi::Array result; result.reserve(this->records.size()); @@ -198,7 +200,8 @@ class TensorCache { static void Update(ffi::String name, Tensor arr, bool override) { TensorCache* pool = Global(); if (!override) { - ICHECK_EQ(pool->pool_.count(name), 0) << "Name " << name << " already exists in the cache"; + TVM_FFI_ICHECK_EQ(pool->pool_.count(name), 0) + << "Name " << name << " already exists in the cache"; } pool->pool_.Set(name, arr); } @@ -236,8 +239,8 @@ class TensorCache { try { params = shard_rec.Load(device, cache_path, &raw_data, &staging_buffer); } catch (const std::runtime_error& e) { - LOG(FATAL) << "ValueError: Error when loading parameters from " << shard_rec.data_path - << ": " << e.what(); + TVM_FFI_THROW(ValueError) << "Error when loading parameters from " << shard_rec.data_path + << ": " << e.what(); } int num_params = params.size(); for (int i = 0; i < num_params; ++i) { @@ -256,7 +259,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("vm.builtin.tensor_cache.get", TensorCache::Get) .def_packed("vm.builtin.tensor_cache.update", [](ffi::PackedArgs args, ffi::Any* rv) { - CHECK(args.size() == 2 || args.size() == 3); + TVM_FFI_ICHECK(args.size() == 2 || args.size() == 3); ffi::String name = args[0].cast(); bool is_override = args.size() == 2 ? false : args[2].cast(); @@ -307,7 +310,7 @@ class ParamModuleNode : public ffi::ModuleObj { params.push_back(opt.value()); } else { if (num_params == -1) return params; - LOG(FATAL) << "Cannot find " << name << " in cache"; + TVM_FFI_THROW(InternalError) << "Cannot find " << name << " in cache"; } } return params; @@ -320,7 +323,7 @@ class ParamModuleNode : public ffi::ModuleObj { if (ffi::Optional opt = TensorCache::Get(name)) { result.push_back(opt.value()); } else { - LOG(FATAL) << "ValueError: Cannot find parameter in cache: " << name; + TVM_FFI_THROW(ValueError) << "Cannot find parameter in cache: " << name; } } return result; @@ -355,8 +358,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { names.reserve(args.size()); for (int i = 0; i < args.size(); ++i) { if (!args[i].try_cast()) { - LOG(FATAL) << "ValueError: Expect string as input, but get " - << args[i].GetTypeKey() << " at " << i; + TVM_FFI_THROW(ValueError) << "Expect string as input, but get " + << args[i].GetTypeKey() << " at " << i; } names.push_back(args[i].cast()); } diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index cd5475bf1ad5..38991d7714d7 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -71,13 +71,13 @@ ffi::Any IndexIntoNestedObject(ffi::Any obj, ffi::PackedArgs args, int starting_ for (int i = starting_arg_idx; i < args.size(); i++) { // the object must be an Array to be able to index into it if (!obj.as()) { - LOG(FATAL) << "ValueError: Attempted to index into an object that is not an Array."; + TVM_FFI_THROW(ValueError) << "Attempted to index into an object that is not an Array."; } int index = args[i].cast(); auto arr = obj.cast>(); // make sure the index is in bounds if (index >= static_cast(arr.size())) { - LOG(FATAL) << "IndexError: Invalid index (" << index << " >= " << arr.size() << ")."; + TVM_FFI_THROW(IndexError) << "Invalid index (" << index << " >= " << arr.size() << ")."; } obj = arr[index]; } @@ -326,7 +326,7 @@ class VirtualMachineImpl : public VirtualMachine { vm->frames_.emplace_back(std::move(frame)); } ~FrameGuard() { - ICHECK_GT(vm->frames_.size(), 0); + TVM_FFI_ICHECK_GT(vm->frames_.size(), 0); vm->pc_ = vm->frames_.back()->return_pc; vm->frames_.back()->Clear(); vm->frame_free_list_.emplace_back(std::move(vm->frames_.back())); @@ -360,7 +360,7 @@ class VirtualMachineImpl : public VirtualMachine { * \param obj The object to write to. */ TVM_ALWAYS_INLINE void WriteRegister(VMFrame* frame, RegName reg, const RegType& obj) { - ICHECK_LT(reg, frame->register_file.size()); + TVM_FFI_ICHECK_LT(reg, frame->register_file.size()); frame->register_file[reg] = obj; } /*! @@ -377,7 +377,7 @@ class VirtualMachineImpl : public VirtualMachine { if (reg == Instruction::kVoidRegister) { ret = nullptr; } else { - ICHECK_EQ(reg, Instruction::kVMRegister); + TVM_FFI_ICHECK_EQ(reg, Instruction::kVMRegister); // per convention, ctx ptr must be VirtualMachine* casted to void. // this and VirtualMachine* may or may not be the same // do first cast to VirtualMachine* then to void* @@ -461,7 +461,7 @@ void VirtualMachineImpl::LoadExecutable(ObjectPtr exec) { void VirtualMachineImpl::Init(const std::vector& devices, const std::vector& alloc_types) { - ICHECK_EQ(devices.size(), alloc_types.size()); + TVM_FFI_ICHECK_EQ(devices.size(), alloc_types.size()); this->devices.reserve(devices.size()); this->allocators.reserve(alloc_types.size()); @@ -485,17 +485,17 @@ void VirtualMachineImpl::Init(const std::vector& devices, } VMFuncInfo VirtualMachineImpl::LookupVMFuncInfo(const std::string& func_name) { - ICHECK(exec_) << "The executable is not created yet."; + TVM_FFI_ICHECK(exec_) << "The executable is not created yet."; auto it = this->exec_->func_map.find(func_name); - CHECK(it != this->exec_->func_map.end()) << "ValueError: Unknown function: " << func_name; + TVM_FFI_CHECK(it != this->exec_->func_map.end(), ValueError) << "Unknown function: " << func_name; return exec_->func_table[it->second]; } RegType VirtualMachineImpl::LookupVMOutput(const std::string& func_name) { if (!outputs_.count(func_name)) { - LOG(FATAL) << "ValueError: No output saved for call of \"" << func_name - << "\"; use `invoke_stateful` to call it first."; + TVM_FFI_THROW(ValueError) << "No output saved for call of \"" << func_name + << "\"; use `invoke_stateful` to call it first."; } return outputs_[func_name]; } @@ -507,7 +507,7 @@ void VirtualMachineImpl::SetInput(std::string func_name, bool with_param_module, Index gf_idx = m.at(func_name); const VMFuncInfo& vm_func = exec_->func_table[gf_idx]; size_t params_num = vm_func.num_args; - ICHECK_EQ(args.size(), params_num) + TVM_FFI_ICHECK_EQ(args.size(), params_num) << "The number of provided parameters doesn't match the number of arguments for"; std::vector func_args(params_num); for (int i = 0; i < args.size(); ++i) { @@ -520,7 +520,7 @@ void VirtualMachineImpl::SetInput(std::string func_name, bool with_param_module, } inputs_[func_name] = func_args; } else { - LOG(FATAL) << "ValueError: Unknown function: " << func_name; + TVM_FFI_THROW(ValueError) << "Unknown function: " << func_name; } } @@ -536,7 +536,7 @@ void VirtualMachineImpl::InvokeClosurePacked(const ObjectRef& closure_or_packedf } // run closure call. auto* clo = closure_or_packedfunc.as(); - ICHECK(clo != nullptr) << "Function expects a closure or ffi::Function "; + TVM_FFI_ICHECK(clo != nullptr) << "Function expects a closure or ffi::Function "; std::vector packed_args(args.size() + 1); // per convention, ctx ptr must be VirtualMachine* casted to void. @@ -570,7 +570,7 @@ RegType VirtualMachineImpl::InvokeClosureInternal(const ObjectRef& closure_or_pa if (packed != nullptr) { packed->CallPacked(packed_args.data(), packed_args.size(), &ret); } else { - ICHECK(clo != nullptr); + TVM_FFI_ICHECK(clo != nullptr); clo->impl.CallPacked(packed_args.data(), packed_args.size(), &ret); } return ret; @@ -603,7 +603,7 @@ ffi::Optional VirtualMachineImpl::GetClosureInternal(const ffi::Strin auto it = exec_->func_map.find(func_name); if (it == exec_->func_map.end()) { if (allow_missing) return std::nullopt; - LOG(FATAL) << "ValueError: Unknown function: " << func_name; + TVM_FFI_THROW(ValueError) << "Unknown function: " << func_name; } Index gf_idx = it->second; @@ -623,18 +623,18 @@ ffi::Optional VirtualMachineImpl::GetClosureInternal(const ffi::Strin }); return VMClosure(func_name, impl); } else { - ICHECK(finfo.kind == VMFuncInfo::FuncKind::kVMTIRFunc) + TVM_FFI_ICHECK(finfo.kind == VMFuncInfo::FuncKind::kVMTIRFunc) << "Cannot support closure with function kind " << static_cast(finfo.kind); ffi::Optional tir_func = GetFuncFromImports("__vmtir__" + finfo.name); - ICHECK(tir_func.has_value()) << "Cannot find underlying compiled tir function of VMTIRFunc " - << finfo.name; + TVM_FFI_ICHECK(tir_func.has_value()) + << "Cannot find underlying compiled tir function of VMTIRFunc " << finfo.name; auto impl = ffi::Function([this, finfo, tir_func](ffi::PackedArgs args, ffi::Any* rv) { // Per convention, ctx ptr is a VirtualMachine* VirtualMachine* ctx_ptr = static_cast(args[0].cast()); - ICHECK(ctx_ptr == this); - ICHECK_EQ(args.size() - 1, finfo.num_args) + TVM_FFI_ICHECK(ctx_ptr == this); + TVM_FFI_ICHECK_EQ(args.size() - 1, finfo.num_args) << "Function " << finfo.name << " expects " << finfo.num_args << " arguments"; - ICHECK_GE(finfo.register_file_size, finfo.num_args + 1); + TVM_FFI_ICHECK_GE(finfo.register_file_size, finfo.num_args + 1); std::vector reg_file(finfo.register_file_size); for (int64_t i = 0; i < finfo.num_args; ++i) { reg_file[i] = args[i + 1]; @@ -656,7 +656,7 @@ ffi::Optional VirtualMachineImpl::GetClosureInternal(const ffi::Strin //-------------------------------------------------------------------- RegType VirtualMachineImpl::InvokeBytecode(Index gf_idx, const std::vector& args) { const VMFuncInfo& gfunc = exec_->func_table[gf_idx]; - ICHECK(gfunc.kind == VMFuncInfo::FuncKind::kVMFunc); + TVM_FFI_ICHECK(gfunc.kind == VMFuncInfo::FuncKind::kVMFunc); // Get the curr instr which might be a potential caller. Instruction curr_instr = exec_->GetInstruction(pc_); @@ -668,9 +668,8 @@ RegType VirtualMachineImpl::InvokeBytecode(Index gf_idx, const std::vector(gfunc.num_args), args.size()) << "ValueError: Invoking function " - << gfunc.name << " expects " - << gfunc.num_args << " arguments" << + TVM_FFI_ICHECK_EQ(static_cast(gfunc.num_args), args.size()) + << "Invoking function " << gfunc.name << " expects " << gfunc.num_args << " arguments" << [&]() { std::stringstream ss; if (gfunc.param_names.size()) { @@ -684,7 +683,8 @@ RegType VirtualMachineImpl::InvokeBytecode(Index gf_idx, const std::vectorGetClosure(info.name); func_pool_[func_index] = clo; } @@ -749,19 +749,19 @@ void VirtualMachineImpl::RunInstrCall(VMFrame* curr_frame, Instruction instr) { break; } case Instruction::ArgKind::kFuncIdx: { - ICHECK_LT(static_cast(arg.value()), this->func_pool_.size()); + TVM_FFI_ICHECK_LT(static_cast(arg.value()), this->func_pool_.size()); call_args[arg_index] = this->func_pool_[arg.value()]; break; } default: { - LOG(FATAL) << "ValueError: Unknown argument kind: " << int(arg.kind()); + TVM_FFI_THROW(ValueError) << "Unknown argument kind: " << int(arg.kind()); } } } ffi::PackedArgs args(call_args.data() + args_begin_offset, instr.num_args); ffi::Any ret; - ICHECK_LT(static_cast(instr.func_idx), this->func_pool_.size()); + TVM_FFI_ICHECK_LT(static_cast(instr.func_idx), this->func_pool_.size()); if (instrument_ == nullptr) { this->InvokeClosurePacked(func_pool_[instr.func_idx].cast(), args, &ret); @@ -809,7 +809,8 @@ void VirtualMachineImpl::RunLoop() { VMFrame* curr_frame = frames_.back().get(); while (true) { - ICHECK_LT(static_cast(pc_), exec_->instr_offset.size()) << "run into invalid section"; + TVM_FFI_ICHECK_LT(static_cast(pc_), exec_->instr_offset.size()) + << "run into invalid section"; Instruction instr = exec_->GetInstruction(pc_); switch (instr.op) { case Opcode::Call: { @@ -841,7 +842,7 @@ void VirtualMachineImpl::RunLoop() { if (cond_val != 0) { pc_++; } else { - ICHECK_GT(instr.false_offset, 1); + TVM_FFI_ICHECK_GT(instr.false_offset, 1); pc_ += instr.false_offset; } break; @@ -859,7 +860,7 @@ ObjectPtr VirtualMachine::Create() { //-------------------------------------------------------------------- void VirtualMachineImpl::_Init(ffi::PackedArgs args, ffi::Any* rv) { - ICHECK_EQ(args.size() % 3, 0); + TVM_FFI_ICHECK_EQ(args.size() % 3, 0); std::vector devices; std::vector alloc_types; for (int i = 0; i < args.size(); i += 3) { @@ -873,7 +874,7 @@ void VirtualMachineImpl::_Init(ffi::PackedArgs args, ffi::Any* rv) { } void VirtualMachineImpl::_SaveClosure(ffi::PackedArgs args, ffi::Any* rv) { - ICHECK_GE(args.size(), 3); + TVM_FFI_ICHECK_GE(args.size(), 3); std::string func_name = args[0].cast(); this->SaveClosure(func_name, args[1].cast(), args[2].cast(), args.Slice(3)); } @@ -885,11 +886,11 @@ void VirtualMachineImpl::_InvokeClosure(ffi::PackedArgs args, ffi::Any* rv) { void VirtualMachineImpl::_InvokeClosureStateful(std::string func_name) { const std::unordered_map& m = this->exec_->func_map; if (m.find(func_name) == m.end()) { - LOG(FATAL) << "ValueError: Unknown function: " << func_name; + TVM_FFI_THROW(ValueError) << "Unknown function: " << func_name; } if (!inputs_.count(func_name)) { - LOG(FATAL) << "ValueError: No inputs set for stateful call of " << func_name - << "; use `set_input` first."; + TVM_FFI_THROW(ValueError) << "No inputs set for stateful call of " << func_name + << "; use `set_input` first."; return; } outputs_[func_name] = this->InvokeClosureInternal(func_pool_[m.at(func_name)].cast(), @@ -902,7 +903,7 @@ void VirtualMachineImpl::_SetInstrument(ffi::PackedArgs args, ffi::Any* rv) { } else { ffi::String func_name = args[0].cast(); const auto factory = tvm::ffi::Function::GetGlobal(func_name); - CHECK(factory.has_value()) << "Cannot find factory " << func_name; + TVM_FFI_ICHECK(factory.has_value()) << "Cannot find factory " << func_name; ffi::Any rv; factory->CallPacked(args.Slice(1), &rv); this->SetInstrument(rv.cast()); @@ -925,8 +926,8 @@ void VirtualMachineImpl::_GetOutput(ffi::PackedArgs args, ffi::Any* rv) { RegType out = LookupVMOutput(func_name); ffi::Any obj = IndexIntoNestedObject(out, args, 1); if (obj.as()) { - LOG(FATAL) << "ValueError: `get_output` cannot return a tuple for RPC compatibility. " - "Please specify another index argument."; + TVM_FFI_THROW(ValueError) << "`get_output` cannot return a tuple for RPC compatibility. " + "Please specify another index argument."; return; } *rv = obj; @@ -950,8 +951,8 @@ int VirtualMachineImpl::_GetFunctionArity(std::string func_name) { std::string VirtualMachineImpl::_GetFunctionParamName(std::string func_name, int index) { const VMFuncInfo& vm_func = LookupVMFuncInfo(func_name); if (static_cast(index) >= vm_func.param_names.size()) { - LOG(FATAL) << "ValueError: Invalid index for " << func_name << " (" << index << " out of " - << vm_func.param_names.size() << ")"; + TVM_FFI_THROW(ValueError) << "Invalid index for " << func_name << " (" << index << " out of " + << vm_func.param_names.size() << ")"; } return vm_func.param_names[index]; } @@ -961,7 +962,7 @@ ffi::Function VirtualMachineImpl::_LookupFunction(const ffi::String& name) { return ffi::Function([clo = opt.value(), _self = ffi::GetRef(this)]( ffi::PackedArgs args, ffi::Any* rv) -> void { auto* self = const_cast(_self.as()); - ICHECK(self); + TVM_FFI_ICHECK(self); self->InvokeClosurePacked(clo, args, rv); }); } @@ -999,12 +1000,12 @@ class VirtualMachineProfiler : public VirtualMachineImpl { bool clear_inputs = false; if (inputs.size() == 0) { - ICHECK(args.size() > 1) << "No input is provided"; + TVM_FFI_ICHECK(args.size() > 1) << "No input is provided"; SetInput(f_name, false, args.Slice(1)); inputs = GetInputsFor(f_name); clear_inputs = true; } else { - ICHECK_EQ(args.size(), 1) << "Inputs are already provided by set_input."; + TVM_FFI_ICHECK_EQ(args.size(), 1) << "Inputs are already provided by set_input."; } // warmup @@ -1085,7 +1086,7 @@ ObjectPtr VirtualMachine::CreateProfiler() { #else ObjectPtr VirtualMachine::CreateProfiler() { - LOG(FATAL) << "Profiler support is disabled"; + TVM_FFI_THROW(InternalError) << "Profiler support is disabled"; return nullptr; } #endif // TVM_VM_ENABLE_PROFILER diff --git a/src/runtime/vulkan/vulkan_common.cc b/src/runtime/vulkan/vulkan_common.cc index 30df8b86ecd5..22ad359abab5 100644 --- a/src/runtime/vulkan/vulkan_common.cc +++ b/src/runtime/vulkan/vulkan_common.cc @@ -38,7 +38,7 @@ std::vector FindEnabledExtensions( std::vector enabled_extensions; for (const auto& ext : required_extensions) { - ICHECK(available_extensions.count(ext)) + TVM_FFI_ICHECK(available_extensions.count(ext)) << "Required vulkan extension \"" << ext << "\" not supported by driver"; enabled_extensions.push_back(ext); } diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index fb4776c98afc..95744707ca2c 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -93,10 +93,10 @@ inline const char* VKGetErrorString(VkResult error) { * \brief Protected Vulkan call * \param func Expression to call. */ -#define VULKAN_CHECK_ERROR(__e) \ - { \ - ICHECK(__e == VK_SUCCESS) << "Vulkan Error, code=" << __e << ": " \ - << vulkan::VKGetErrorString(__e); \ +#define VULKAN_CHECK_ERROR(__e) \ + { \ + TVM_FFI_ICHECK(__e == VK_SUCCESS) \ + << "Vulkan Error, code=" << __e << ": " << vulkan::VKGetErrorString(__e); \ } #define VULKAN_CALL(func) \ diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc index abdc0c0ce001..f1d3dd2c626e 100644 --- a/src/runtime/vulkan/vulkan_device.cc +++ b/src/runtime/vulkan/vulkan_device.cc @@ -90,12 +90,14 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, if (instance.HasExtension("VK_KHR_get_physical_device_properties2")) { // Preferred method, call to get all properties that can be queried. - auto vkGetPhysicalDeviceProperties2KHR = (PFN_vkGetPhysicalDeviceProperties2KHR)ICHECK_NOTNULL( - vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceProperties2KHR")); + auto vkGetPhysicalDeviceProperties2KHR = + (PFN_vkGetPhysicalDeviceProperties2KHR)TVM_FFI_ICHECK_NOTNULL( + vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceProperties2KHR")); vkGetPhysicalDeviceProperties2KHR(device, &properties); - auto vkGetPhysicalDeviceFeatures2KHR = (PFN_vkGetPhysicalDeviceFeatures2KHR)ICHECK_NOTNULL( - vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceFeatures2KHR")); + auto vkGetPhysicalDeviceFeatures2KHR = + (PFN_vkGetPhysicalDeviceFeatures2KHR)TVM_FFI_ICHECK_NOTNULL( + vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceFeatures2KHR")); vkGetPhysicalDeviceFeatures2KHR(device, &features); } else { // Fallback, get as many features as we can from the Vulkan1.0 @@ -185,7 +187,8 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, device_type = "cpu"; break; default: - LOG(FATAL) << "Unknown vulkan device type: " << properties.properties.deviceType; + TVM_FFI_THROW(InternalError) + << "Unknown vulkan device type: " << properties.properties.deviceType; break; } @@ -218,25 +221,29 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, } VulkanDescriptorTemplateKHRFunctions::VulkanDescriptorTemplateKHRFunctions(VkDevice device) { - vkCreateDescriptorUpdateTemplateKHR = (PFN_vkCreateDescriptorUpdateTemplateKHR)ICHECK_NOTNULL( - vkGetDeviceProcAddr(device, "vkCreateDescriptorUpdateTemplateKHR")); - vkDestroyDescriptorUpdateTemplateKHR = (PFN_vkDestroyDescriptorUpdateTemplateKHR)ICHECK_NOTNULL( - vkGetDeviceProcAddr(device, "vkDestroyDescriptorUpdateTemplateKHR")); - vkUpdateDescriptorSetWithTemplateKHR = (PFN_vkUpdateDescriptorSetWithTemplateKHR)ICHECK_NOTNULL( - vkGetDeviceProcAddr(device, "vkUpdateDescriptorSetWithTemplateKHR")); - vkCmdPushDescriptorSetWithTemplateKHR = (PFN_vkCmdPushDescriptorSetWithTemplateKHR)ICHECK_NOTNULL( - vkGetDeviceProcAddr(device, "vkCmdPushDescriptorSetWithTemplateKHR")); + vkCreateDescriptorUpdateTemplateKHR = + (PFN_vkCreateDescriptorUpdateTemplateKHR)TVM_FFI_ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkCreateDescriptorUpdateTemplateKHR")); + vkDestroyDescriptorUpdateTemplateKHR = + (PFN_vkDestroyDescriptorUpdateTemplateKHR)TVM_FFI_ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkDestroyDescriptorUpdateTemplateKHR")); + vkUpdateDescriptorSetWithTemplateKHR = + (PFN_vkUpdateDescriptorSetWithTemplateKHR)TVM_FFI_ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkUpdateDescriptorSetWithTemplateKHR")); + vkCmdPushDescriptorSetWithTemplateKHR = + (PFN_vkCmdPushDescriptorSetWithTemplateKHR)TVM_FFI_ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkCmdPushDescriptorSetWithTemplateKHR")); } VulkanGetBufferMemoryRequirements2Functions::VulkanGetBufferMemoryRequirements2Functions( VkDevice device) { - vkGetBufferMemoryRequirements2KHR = (PFN_vkGetBufferMemoryRequirements2KHR)ICHECK_NOTNULL( + vkGetBufferMemoryRequirements2KHR = (PFN_vkGetBufferMemoryRequirements2KHR)TVM_FFI_ICHECK_NOTNULL( vkGetDeviceProcAddr(device, "vkGetBufferMemoryRequirements2KHR")); } VulkanQueueInsertDebugUtilsLabelFunctions::VulkanQueueInsertDebugUtilsLabelFunctions( VkInstance instance) { - vkQueueInsertDebugUtilsLabelEXT = (PFN_vkQueueInsertDebugUtilsLabelEXT)ICHECK_NOTNULL( + vkQueueInsertDebugUtilsLabelEXT = (PFN_vkQueueInsertDebugUtilsLabelEXT)TVM_FFI_ICHECK_NOTNULL( vkGetInstanceProcAddr(instance, "vkQueueInsertDebugUtilsLabelEXT")); } @@ -307,7 +314,7 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_ coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; } } - ICHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device."; + TVM_FFI_ICHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device."; win_rank = -1; for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { @@ -328,7 +335,7 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_ } } - ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device."; + TVM_FFI_ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device."; if (device_properties.supports_push_descriptor) { descriptor_template_khr_functions = @@ -574,8 +581,8 @@ void VulkanDevice::AllocateThreadLocalUniformBuffer(size_t min_size) { VulkanStagingBuffer& VulkanDevice::ThreadLocalUniformBuffer(size_t min_size) { VulkanStagingBuffer* buffer = uniform_buffer_per_thread.Get(); - ICHECK(buffer) << "Vulkan uniform buffer requested, but not previously allocated."; - ICHECK_GE(buffer->size, min_size) + TVM_FFI_ICHECK(buffer) << "Vulkan uniform buffer requested, but not previously allocated."; + TVM_FFI_ICHECK_GE(buffer->size, min_size) << "Vulkan uniform buffer of size " << min_size << " requested, but only " << buffer->size << " was previously allocated."; return *buffer; @@ -598,7 +605,7 @@ uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info, } type_bits >>= 1; } - LOG(FATAL) << "Requested memory type not found"; + TVM_FFI_THROW(InternalError) << "Requested memory type not found"; return 0; } diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index a2ff8bb7ce0e..3dc5f146dd6a 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -79,10 +79,10 @@ VulkanDeviceAPI::VulkanDeviceAPI() { VulkanDeviceAPI::~VulkanDeviceAPI() {} void VulkanDeviceAPI::SetDevice(Device dev) { - ICHECK_EQ(dev.device_type, kDLVulkan) + TVM_FFI_ICHECK_EQ(dev.device_type, kDLVulkan) << "Active vulkan device cannot be set to non-vulkan device" << dev; - ICHECK_LE(dev.device_id, static_cast(devices_.size())) + TVM_FFI_ICHECK_LE(dev.device_id, static_cast(devices_.size())) << "Attempted to set active vulkan device to device_id==" << dev.device_id << ", but only " << devices_.size() << " devices present"; @@ -309,33 +309,33 @@ void* VulkanDeviceAPI::AllocWorkspace(Device dev, size_t size, DLDataType type_h void VulkanDeviceAPI::FreeWorkspace(Device dev, void* data) { auto* pool = pool_per_thread.Get(); - ICHECK(pool) << "Attempted to free a vulkan workspace on a CPU-thread " - << "that has never allocated a workspace"; + TVM_FFI_ICHECK(pool) << "Attempted to free a vulkan workspace on a CPU-thread " + << "that has never allocated a workspace"; pool->FreeWorkspace(dev, data); } TVMStreamHandle VulkanDeviceAPI::CreateStream(Device dev) { return nullptr; } void VulkanDeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) { - ICHECK_EQ(stream, static_cast(nullptr)); + TVM_FFI_ICHECK_EQ(stream, static_cast(nullptr)); } // Syncing two streams is a nop, since there is only one stream. void VulkanDeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) { - ICHECK_EQ(event_src, static_cast(nullptr)); - ICHECK_EQ(event_dst, static_cast(nullptr)); + TVM_FFI_ICHECK_EQ(event_src, static_cast(nullptr)); + TVM_FFI_ICHECK_EQ(event_dst, static_cast(nullptr)); } void VulkanDeviceAPI::StreamSync(Device dev, TVMStreamHandle stream) { - ICHECK_EQ(stream, static_cast(nullptr)); + TVM_FFI_ICHECK_EQ(stream, static_cast(nullptr)); device(dev.device_id).ThreadLocalStream().Synchronize(); } void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, Device dev_from, Device dev_to, DLDataType type_hint, TVMStreamHandle stream) { - ICHECK(stream == nullptr); + TVM_FFI_ICHECK(stream == nullptr); Device dev = dev_from; if (dev_from.device_type == kDLCPU) { dev = dev_to; @@ -344,7 +344,7 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* int from_dev_type = static_cast(dev_from.device_type); int to_dev_type = static_cast(dev_to.device_type); if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) { - ICHECK_EQ(dev_from.device_id, dev_to.device_id) + TVM_FFI_ICHECK_EQ(dev_from.device_id, dev_to.device_id) << "The Vulkan runtime does not support deviceA to deviceB copies. " << "This should be changed to a deviceA to CPU copy, followed by a CPU to deviceB copy"; @@ -436,14 +436,15 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* // Stream? This would allow us to elide synchronizations here. stream.Synchronize(); } else { - LOG(FATAL) << "Expect copy from/to Vulkan or between Vulkan" - << ", from=" << from_dev_type << ", to=" << to_dev_type; + TVM_FFI_THROW(InternalError) << "Expect copy from/to Vulkan or between Vulkan" + << ", from=" << from_dev_type << ", to=" << to_dev_type; } } const VulkanDevice& VulkanDeviceAPI::device(size_t device_id) const { - ICHECK_LT(device_id, devices_.size()) << "Requested Vulkan device_id=" << device_id - << ", but only " << devices_.size() << " devices present"; + TVM_FFI_ICHECK_LT(device_id, devices_.size()) + << "Requested Vulkan device_id=" << device_id << ", but only " << devices_.size() + << " devices present"; return devices_[device_id]; } diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/runtime/vulkan/vulkan_module.cc index 446ec11164e5..926711535123 100644 --- a/src/runtime/vulkan/vulkan_module.cc +++ b/src/runtime/vulkan/vulkan_module.cc @@ -48,7 +48,7 @@ ffi::Module VulkanModuleLoadFile(const std::string& file_name, const ffi::String support::BytesInStream stream(data); uint32_t magic; stream.Read(&magic); - ICHECK_EQ(magic, kVulkanModuleMagic) << "VulkanModule Magic mismatch"; + TVM_FFI_ICHECK_EQ(magic, kVulkanModuleMagic) << "VulkanModule Magic mismatch"; stream.Read(&smap); return VulkanModuleCreate(smap, fmap, ""); } @@ -60,7 +60,7 @@ ffi::Module VulkanModuleLoadFromBytes(const ffi::Bytes& bytes) { std::string fmt; stream.Read(&fmt); ffi::Map fmap; - ICHECK(stream.Read(&fmap)); + TVM_FFI_ICHECK(stream.Read(&fmap)); stream.Read(&smap); return VulkanModuleCreate(smap, fmap, ""); } diff --git a/src/runtime/vulkan/vulkan_stream.cc b/src/runtime/vulkan/vulkan_stream.cc index a84f4ae568b9..49ed530a6102 100644 --- a/src/runtime/vulkan/vulkan_stream.cc +++ b/src/runtime/vulkan/vulkan_stream.cc @@ -82,7 +82,7 @@ void VulkanStream::Launch(const std::function& kernel) void VulkanStream::LaunchDeferred(const std::function& deferred_initializer, const std::function& deferred_kernel, const VulkanStreamToken& deferred_token) { - ICHECK(!device_->UseImmediate()); + TVM_FFI_ICHECK(!device_->UseImmediate()); // If the new kernel uses the same descriptor set as one of the // kernels already in the command buffer, we need to synchronize @@ -90,7 +90,7 @@ void VulkanStream::LaunchDeferred(const std::function& deferred_initiali if (std::any_of(deferred_tokens_[deferred_token.descriptor_set_].begin(), deferred_tokens_[deferred_token.descriptor_set_].end(), [&](const VulkanStreamToken& token) { - DCHECK(token.descriptor_set_ == deferred_token.descriptor_set_); + TVM_FFI_DCHECK(token.descriptor_set_ == deferred_token.descriptor_set_); return token.descriptor_set_ == deferred_token.descriptor_set_ && token.buffers_ != deferred_token.buffers_; })) { @@ -105,7 +105,7 @@ void VulkanStream::LaunchDeferred(const std::function& deferred_initiali if (!std::any_of(deferred_tokens_[deferred_token.descriptor_set_].begin(), deferred_tokens_[deferred_token.descriptor_set_].end(), [&](const VulkanStreamToken& token) { - DCHECK(token.descriptor_set_ == deferred_token.descriptor_set_); + TVM_FFI_DCHECK(token.descriptor_set_ == deferred_token.descriptor_set_); return token.descriptor_set_ == deferred_token.descriptor_set_ && token.buffers_ == deferred_token.buffers_; })) { @@ -125,8 +125,8 @@ void VulkanStream::Synchronize() { deferred_kernels_.clear(); deferred_tokens_.clear(); } else { - DCHECK_EQ(deferred_kernels_.size(), 0); - DCHECK_EQ(deferred_tokens_.size(), 0); + TVM_FFI_DCHECK_EQ(deferred_kernels_.size(), 0); + TVM_FFI_DCHECK_EQ(deferred_tokens_.size(), 0); } VULKAN_CALL(vkEndCommandBuffer(state_->cmd_buffer_)); diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index da78fdf4ba2e..952e546fdd48 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -75,7 +75,7 @@ void VulkanWrappedFunc::operator()(ffi::PackedArgs args, ffi::Any* rv, // Can safely capture by reference as this lambda is immediately executed on the calling thread. device.ThreadLocalStream().Launch([&](VulkanStreamState* state) { vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); - ICHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE); + TVM_FFI_ICHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE); device.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR( state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0, descriptor_buffers.data()); @@ -190,7 +190,7 @@ VulkanModuleNode::~VulkanModuleNode() { for (size_t device_id = 0; device_id < ecache_.size(); ++device_id) { for (auto& kv : ecache_[device_id]) { auto& pe = kv.second; - ICHECK(pe); + TVM_FFI_ICHECK(pe); const auto& device = VulkanDeviceAPI::Global()->device(device_id); if (pe->descriptor_update_template != VK_NULL_HANDLE) { @@ -208,7 +208,7 @@ VulkanModuleNode::~VulkanModuleNode() { ffi::Optional VulkanModuleNode::GetFunction(const ffi::String& name) { ObjectPtr sptr_to_self = ffi::GetObjectPtr(this); - ICHECK_EQ(sptr_to_self.get(), this); + TVM_FFI_ICHECK_EQ(sptr_to_self.get(), this); auto opt_info = fmap_.Get(name); if (!opt_info.has_value()) return std::nullopt; FunctionInfo info = opt_info.value(); @@ -233,7 +233,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, { // create shader auto sit = smap_.find(func_name); - ICHECK(sit != smap_.end()); + TVM_FFI_ICHECK(sit != smap_.end()); pe->use_ubo = sit->second.flag & (1 << ShaderMetaDataFlagMask::kUseUBO); const std::vector& data = sit->second.data; VkShaderModuleCreateInfo shader_cinfo; @@ -287,7 +287,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, { auto opt_info = fmap_.Get(func_name); - ICHECK(opt_info.has_value()); + TVM_FFI_ICHECK(opt_info.has_value()); FunctionInfo finfo = opt_info.value(); for (DLDataType arg_type : finfo->arg_types) { if (arg_type.code == kDLOpaqueHandle) { @@ -355,7 +355,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, if (0 < nbytes_scalars && !pe->use_ubo) { playout_cinfo.pushConstantRangeCount = 1; playout_cinfo.pPushConstantRanges = &crange; - ICHECK_LE(crange.size, device.device_properties.max_push_constants_size) + TVM_FFI_ICHECK_LE(crange.size, device.device_properties.max_push_constants_size) << "The Vulkan shader uses " << crange.size << " bytes of push constants, but the device only supports " << device.device_properties.max_push_constants_size << "bytes. " @@ -407,7 +407,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, void VulkanModuleNode::WriteToFile(const ffi::String& file_name, const ffi::String& format) const { std::string fmt = GetFileFormat(file_name, format); - ICHECK_EQ(fmt, fmt_) << "Can only save to customized format vulkan"; + TVM_FFI_ICHECK_EQ(fmt, fmt_) << "Can only save to customized format vulkan"; std::string meta_file = GetMetaFilePath(file_name); SaveMetaDataToFile(meta_file, fmap_); std::string result; diff --git a/src/runtime/workspace_pool.cc b/src/runtime/workspace_pool.cc index 6ed5bf4daba6..a5ed75e301fd 100644 --- a/src/runtime/workspace_pool.cc +++ b/src/runtime/workspace_pool.cc @@ -95,7 +95,7 @@ class WorkspacePool::Pool { int index = static_cast(allocated_.size()) - 2; for (; index > 0 && allocated_[index].data != data; --index) { } - ICHECK_GT(index, 0) << "trying to free things that has not been allocated"; + TVM_FFI_ICHECK_GT(index, 0) << "trying to free things that has not been allocated"; e = allocated_[index]; allocated_.erase(allocated_.begin() + index); } @@ -159,7 +159,8 @@ void* WorkspacePool::AllocWorkspace(Device dev, size_t size) { } void WorkspacePool::FreeWorkspace(Device dev, void* ptr) { - ICHECK(static_cast(dev.device_id) < array_.size() && array_[dev.device_id] != nullptr); + TVM_FFI_ICHECK(static_cast(dev.device_id) < array_.size() && + array_[dev.device_id] != nullptr); array_[dev.device_id]->Free(ptr); } diff --git a/src/s_tir/analysis/calculate_allocated_memory.cc b/src/s_tir/analysis/calculate_allocated_memory.cc index 8580d510a271..3158d163f28f 100644 --- a/src/s_tir/analysis/calculate_allocated_memory.cc +++ b/src/s_tir/analysis/calculate_allocated_memory.cc @@ -63,7 +63,7 @@ tvm::ffi::Map AllocationCalculator::operator()(const Pr std::string GetStorageScope(const Var& var) { auto* ptr = var->type_annotation.as(); - ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; + TVM_FFI_ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; return ptr->storage_scope; } @@ -111,8 +111,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (auto mod = obj.as()) { return CalculateAllocatedBytes(mod.value()); } else { - LOG(FATAL) << "TypeError: Expect the input to be either PrimFunc or IRModule, but gets: " - << obj->GetTypeKey(); + TVM_FFI_THROW(TypeError) + << "Expect the input to be either PrimFunc or IRModule, but gets: " + << obj->GetTypeKey(); throw; } }); @@ -189,10 +190,11 @@ Pass VerifyVTCMLimit(ffi::Optional default_target) { auto sizes = CalculateAllocatedBytes(func)["main"]; const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0); if (vtcm_allocated.IntValue() > limit.value()) { - LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been exceeded " - << "(allocated: " << vtcm_allocated << ", limit: " << limit.value() << ").\n" - << "In function\n" - << func; + TVM_FFI_THROW(RuntimeError) + << "The global.vtcm memory allocation limit has been exceeded " + << "(allocated: " << vtcm_allocated << ", limit: " << limit.value() << ").\n" + << "In function\n" + << func; } } } diff --git a/src/s_tir/analysis/estimate_flops.cc b/src/s_tir/analysis/estimate_flops.cc index b414a673d500..b11d262281a5 100644 --- a/src/s_tir/analysis/estimate_flops.cc +++ b/src/s_tir/analysis/estimate_flops.cc @@ -265,8 +265,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (auto stmt = obj.as()) { return EstimateTIRFlops(stmt.value()); } else { - LOG(FATAL) << "TypeError: Expect the input to be either IRModule or Stmt, but gets: " - << obj->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Expect the input to be either IRModule or Stmt, but gets: " + << obj->GetTypeKey(); throw; } }); diff --git a/src/s_tir/analysis/find_anchor_sblock.cc b/src/s_tir/analysis/find_anchor_sblock.cc index ce39cf7255f6..ad04aaf284f3 100644 --- a/src/s_tir/analysis/find_anchor_sblock.cc +++ b/src/s_tir/analysis/find_anchor_sblock.cc @@ -52,7 +52,7 @@ Stmt GetEnclosingLoop(const SBlockNode* block, Stmt func_body) { GetRootSeqStmt seq_finder; seq_finder(func_body); - ICHECK(seq_finder.result); + TVM_FFI_ICHECK(seq_finder.result); for (auto stmt : seq_finder.result->seq) { if (stmt->IsInstance()) { @@ -64,7 +64,8 @@ Stmt GetEnclosingLoop(const SBlockNode* block, Stmt func_body) { } } - LOG(FATAL) << "Enclosing loop not found for a block " << ffi::GetRef(block); + TVM_FFI_THROW(InternalError) << "Enclosing loop not found for a block " + << ffi::GetRef(block); TVM_FFI_UNREACHABLE(); } diff --git a/src/s_tir/analysis/identify_memcpy.cc b/src/s_tir/analysis/identify_memcpy.cc index 91cf9e900549..c7edf89c06c8 100644 --- a/src/s_tir/analysis/identify_memcpy.cc +++ b/src/s_tir/analysis/identify_memcpy.cc @@ -303,7 +303,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (auto* ptr = std::get_if(&result)) { output->push_back(StringImm(*ptr)); } else { - LOG(FATAL) << "Internal error, unhandled std::variant type"; + TVM_FFI_THROW(InternalError) << "Internal error, unhandled std::variant type"; } IRVisitorWithAnalyzer::VisitStmt_(op); diff --git a/src/s_tir/analysis/is_pure_function.cc b/src/s_tir/analysis/is_pure_function.cc index 77d8b263b634..ac0cfa900066 100644 --- a/src/s_tir/analysis/is_pure_function.cc +++ b/src/s_tir/analysis/is_pure_function.cc @@ -56,10 +56,11 @@ class PurityChecker : TIRVisitorWithPath { if (!internal_allocations_.count(op->buffer->data)) { is_pure_ = false; - LOG_IF(FATAL, assert_on_error_) << "AssertionError: " - << "Pure functions must not write to buffers, " + if (assert_on_error_) { + TVM_FFI_THROW(AssertionError) << "Pure functions must not write to buffers, " << ", but function contains store to " << op->buffer << op->indices << " of value " << op->value; + } } } @@ -77,11 +78,12 @@ class PurityChecker : TIRVisitorWithPath { if (effect == CallEffectKind::kUpdateState || effect == CallEffectKind::kOpaque) { is_pure_ = false; - LOG_IF(FATAL, assert_on_error_) - << "AssertionError: " - << "Pure functions must not contain calls to impure operators, " - << "but " << ffi::GetRef(call) << " calls operator " << call->op - << ", which has side effect " << effect; + if (assert_on_error_) { + TVM_FFI_THROW(AssertionError) + << "Pure functions must not contain calls to impure operators, " + << "but " << ffi::GetRef(call) << " calls operator " << call->op + << ", which has side effect " << effect; + } } } diff --git a/src/s_tir/analysis/oob_checker.cc b/src/s_tir/analysis/oob_checker.cc index 4fa4c55942c2..18cb2418e497 100644 --- a/src/s_tir/analysis/oob_checker.cc +++ b/src/s_tir/analysis/oob_checker.cc @@ -118,7 +118,8 @@ tvm::transform::Pass OOBChecker() { if (checker.errors.size() > 0) { // mod doesn't contain our function, so we construct a new mod with out function IRModule func_mod({{GlobalVar("main"), func}}); - LOG(FATAL) << OOBError(func_mod, checker.errors).RenderReport("Out of bounds checker"); + TVM_FFI_THROW(ScheduleError) + << OOBError(func_mod, checker.errors).RenderReport("Out of bounds checker"); } return func; }; diff --git a/src/s_tir/analysis/sblock_access_region_detector.cc b/src/s_tir/analysis/sblock_access_region_detector.cc index 2a64d531c856..22c0ed5ad920 100644 --- a/src/s_tir/analysis/sblock_access_region_detector.cc +++ b/src/s_tir/analysis/sblock_access_region_detector.cc @@ -125,7 +125,8 @@ class BlockReadWriteDetector : public StmtExprVisitor { void BlockReadWriteDetector::operator()(const Stmt& stmt) { const auto* block = stmt.as(); - ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); + TVM_FFI_ICHECK(block != nullptr) + << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); for (const MatchBufferRegion& match_buffer : block->match_buffers) { const Var& target_var = match_buffer->buffer->data; const Var& source_var = match_buffer->source->buffer->data; @@ -288,7 +289,7 @@ std::vector BlockReadWriteDetector::ConvertMatchedRegion( Region region; region.reserve(int_sets.size()); - ICHECK_EQ(buffer->shape.size(), int_sets.size()); + TVM_FFI_ICHECK_EQ(buffer->shape.size(), int_sets.size()); for (size_t i = 0; i < int_sets.size(); ++i) { const tvm::arith::IntSet& int_set = int_sets[i]; region.push_back(int_set.CoverRange(Range::FromMinExtent(0, buffer->shape[i]))); @@ -315,11 +316,11 @@ void BlockReadWriteDetector::Update(std::vector* buffers, buffer = match_buffer->source->buffer; region = ConvertMatchedRegion(match_buffer, std::move(region)); } - ICHECK_EQ(buffers->size(), regions->size()) + TVM_FFI_ICHECK_EQ(buffers->size(), regions->size()) << " Expected the buffer and regions to have the same size "; for (size_t i = 0; i < regions->size(); ++i) { if ((*buffers)[i].same_as(buffer)) { - ICHECK_EQ((*regions)[i].size(), region.size()) << "Inconsistent buffer dimension"; + TVM_FFI_ICHECK_EQ((*regions)[i].size(), region.size()) << "Inconsistent buffer dimension"; for (size_t j = 0; j < region.size(); ++j) { (*regions)[i][j] = arith::Union({(*regions)[i][j], region[j]}); } @@ -333,7 +334,7 @@ void BlockReadWriteDetector::Update(std::vector* buffers, ffi::Array BlockReadWriteDetector::CollectRegions( const std::vector& buffers, const std::vector>& regions, const std::unordered_set* excluded_buffers) { - ICHECK_EQ(buffers.size(), regions.size()); + TVM_FFI_ICHECK_EQ(buffers.size(), regions.size()); ffi::Array res; res.reserve(buffers.size()); for (size_t i = 0; i < regions.size(); ++i) { @@ -342,7 +343,7 @@ ffi::Array BlockReadWriteDetector::CollectRegions( } ffi::Array region; region.reserve(regions[i].size()); - ICHECK_EQ(buffers[i]->shape.size(), regions[i].size()); + TVM_FFI_ICHECK_EQ(buffers[i]->shape.size(), regions[i].size()); for (size_t j = 0; j < regions[i].size(); j++) { const tvm::arith::IntSet& range = regions[i][j]; if (range.CanProveSinglePoint(&ana_)) { diff --git a/src/s_tir/analysis/sblock_buffer_access_lca_detector.cc b/src/s_tir/analysis/sblock_buffer_access_lca_detector.cc index da0dab4a97b4..9726f58a738b 100644 --- a/src/s_tir/analysis/sblock_buffer_access_lca_detector.cc +++ b/src/s_tir/analysis/sblock_buffer_access_lca_detector.cc @@ -242,7 +242,7 @@ class LCADetector : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { const auto* iter = op->node.as(); - ICHECK_NOTNULL(iter); + TVM_FFI_ICHECK_NOTNULL(iter); const runtime::ThreadScope& scope = runtime::ThreadScope::Create(iter->thread_tag); if (scope.rank == 0) { blockidx_scopes_.push_back(ancestor_scopes_.back()); @@ -314,7 +314,7 @@ class LCADetector : public StmtExprVisitor { if (rhs->parent_scope_info == nullptr) { return rhs; } - ICHECK(lhs == rhs); + TVM_FFI_ICHECK(lhs == rhs); return lhs; } diff --git a/src/s_tir/analysis/verify_gpu_code.cc b/src/s_tir/analysis/verify_gpu_code.cc index e9b29bef8da5..ed7853a99c6a 100644 --- a/src/s_tir/analysis/verify_gpu_code.cc +++ b/src/s_tir/analysis/verify_gpu_code.cc @@ -95,7 +95,7 @@ class GPUCodeVerifier : public StmtExprVisitor { Var var = op->node.as()->var; const auto* extent = op->value.as(); - ICHECK(extent); + TVM_FFI_ICHECK(extent); std::string name = var.get()->name_hint; // record the number of threads in a block @@ -175,7 +175,7 @@ class GPUCodeVerifier : public StmtExprVisitor { void VisitStmt_(const ForNode* op) { if (op->loop_var->name_hint == "vthread.s") { const auto* extent = op->extent.as(); - ICHECK(extent); + TVM_FFI_ICHECK(extent); size_t num_vthread = static_cast(extent->value); if (num_vthread > max_vthread_) { @@ -311,7 +311,7 @@ std::vector VerifyGPUCode_(const PrimFunc& func, } else if (iter.first == "max_kernels") { max_kernels = val->value; } else { - LOG(FATAL) << "Invalid check item: " << iter.first; + TVM_FFI_THROW(InternalError) << "Invalid check item: " << iter.first; } } @@ -342,9 +342,9 @@ Pass VerifyGPUCode(ffi::Map constraints) { for (auto& err : errs) { s << " " << err << std::endl; } - LOG(FATAL) << "RuntimeError: GPU constraint(s) violated:\n" - << s.str() << " In function\n" - << func.value(); + TVM_FFI_THROW(RuntimeError) << "GPU constraint(s) violated:\n" + << s.str() << " In function\n" + << func.value(); } } } diff --git a/src/s_tir/backend/adreno/inject_texture_alloc.cc b/src/s_tir/backend/adreno/inject_texture_alloc.cc index 7a4a81964689..f9cf0865b074 100644 --- a/src/s_tir/backend/adreno/inject_texture_alloc.cc +++ b/src/s_tir/backend/adreno/inject_texture_alloc.cc @@ -66,11 +66,12 @@ class TextureAllocInjector : public arith::IRMutatorWithAnalyzer { std::string storage_scope = GetStorageScope(op->buffer_var); if (IsTextureStorage(storage_scope)) { op = stmt.as(); - ICHECK(op->extents.size() >= 3) << "Only 2D Array RGBA texture is currently supported"; + TVM_FFI_ICHECK(op->extents.size() >= 3) + << "Only 2D Array RGBA texture is currently supported"; const int data_bits = op->dtype.bits(), vec_length = static_cast(op->extents.back().as()->value); const int channel_size = data_bits * vec_length; - ICHECK(channel_size == 128 || channel_size == 64) + TVM_FFI_ICHECK(channel_size == 128 || channel_size == 64) << "Invalid Channel Size: " << channel_size << " bits"; size_t axis = DefaultTextureLayoutSeparator(op->extents.size(), storage_scope); @@ -91,7 +92,7 @@ class TextureAllocInjector : public arith::IRMutatorWithAnalyzer { protected: std::string GetStorageScope(const Var& buffer_var) { auto* ptr = buffer_var->type_annotation.as(); - ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; + TVM_FFI_ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; return ptr->storage_scope; } }; diff --git a/src/s_tir/backend/adreno/texture_flatten.cc b/src/s_tir/backend/adreno/texture_flatten.cc index 691bb407a010..26786d303b37 100644 --- a/src/s_tir/backend/adreno/texture_flatten.cc +++ b/src/s_tir/backend/adreno/texture_flatten.cc @@ -59,7 +59,7 @@ class TextureLoweringBase : public StmtExprMutator { inline PrimExpr SimplifyOffset(const ffi::Array& shape, const ffi::Array& index) const { PrimExpr base = make_const(DataType::Int(32), 0); - ICHECK_EQ(shape.size(), index.size()); + TVM_FFI_ICHECK_EQ(shape.size(), index.size()); if (index.size() > 0) { PrimExpr offset = index[0]; for (size_t i = 1; i < index.size(); ++i) { @@ -73,7 +73,7 @@ class TextureLoweringBase : public StmtExprMutator { protected: std::string GetStorageScope(const Buffer& buffer) { auto* ptr = buffer->data->type_annotation.as(); - ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; + TVM_FFI_ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; return ptr->storage_scope; } diff --git a/src/s_tir/data_layout.cc b/src/s_tir/data_layout.cc index 9e028d20429b..267368975dfd 100644 --- a/src/s_tir/data_layout.cc +++ b/src/s_tir/data_layout.cc @@ -57,7 +57,7 @@ const LayoutAxis LayoutAxis::LOWER_CASE[] = { LayoutAxis('z')}; const LayoutAxis& LayoutAxis::Get(const char name) { - ICHECK((name >= 'A' && name <= 'Z') || (name >= 'a' && name <= 'z')) + TVM_FFI_ICHECK((name >= 'A' && name <= 'Z') || (name >= 'a' && name <= 'z')) << "Invalid layout axis name: " << name << ". Has to be A-Z or a-z."; return (name >= 'A' && name <= 'Z') ? LayoutAxis::UPPER_CASE[name - 'A'] : LayoutAxis::LOWER_CASE[name - 'a']; @@ -65,12 +65,12 @@ const LayoutAxis& LayoutAxis::Get(const char name) { const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) { const std::string axis = itvar->var.get()->name_hint; - ICHECK_EQ(axis.size(), 1) << "Invalid layout axis " << axis; + TVM_FFI_ICHECK_EQ(axis.size(), 1) << "Invalid layout axis " << axis; return LayoutAxis::Get(axis[0]); } const LayoutAxis& LayoutAxis::Get(const std::string& name) { - ICHECK_EQ(name.length(), 1) << "Invalid axis " << name; + TVM_FFI_ICHECK_EQ(name.length(), 1) << "Invalid axis " << name; return LayoutAxis::Get(name[0]); } @@ -80,13 +80,13 @@ Layout::Layout(const ffi::Array& axes) { std::ostringstream repr; for (const IterVar& axis : axes) { if (const auto* factor = axis->dom->extent.as()) { - ICHECK_GT(factor->value, 0); + TVM_FFI_ICHECK_GT(factor->value, 0); repr << factor->value; } - ICHECK_EQ(axis->var.get()->name_hint.size(), 1) + TVM_FFI_ICHECK_EQ(axis->var.get()->name_hint.size(), 1) << "Invalid layout axis " << axis->var.get()->name_hint; char c = axis->var.get()->name_hint.operator std::string()[0]; - ICHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << "Invalid layout axis " << c; + TVM_FFI_ICHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << "Invalid layout axis " << c; repr << axis->var.get()->name_hint; } node->name = repr.str(); @@ -94,7 +94,7 @@ Layout::Layout(const ffi::Array& axes) { } Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*) - CHECK(dtype.is_int()) << "TypeError: The input dtype should be integer type"; + TVM_FFI_CHECK(dtype.is_int(), TypeError) << "The input dtype should be integer type"; if (name == "__undef__") return; auto node = ffi::make_object(); @@ -106,25 +106,25 @@ Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*) int32_t factor = 0; for (char c : name) { if (c >= 'A' && c <= 'Z') { - ICHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor - << " before dimension " << c; + TVM_FFI_ICHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid factor size " + << factor << " before dimension " << c; std::string shape_name("_shape"); shape_name.insert(0, 1, c); IterVar axis(Range(IntImm(dtype, 0), Var(shape_name, dtype)), Var(std::string(1, c), dtype), tir::kDataPar); node->axes.push_back(axis); } else if (c >= 'a' && c <= 'z') { - ICHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor - << " for dimension " << c; + TVM_FFI_ICHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " + << factor << " for dimension " << c; IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)), Var(std::string(1, c), dtype), tir::kDataPar); node->axes.push_back(axis); factor = 0; } else if (c >= '0' && c <= '9') { - ICHECK(factor >= 0) << "Invalid layout " << name << ": _ is adjacent to a number."; + TVM_FFI_ICHECK(factor >= 0) << "Invalid layout " << name << ": _ is adjacent to a number."; factor = factor * 10 + c - '0'; } else { - LOG(FATAL) << "Invalid layout " << name; + TVM_FFI_THROW(InternalError) << "Invalid layout " << name; } } @@ -132,15 +132,15 @@ Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*) std::vector exist_axis(256, false); for (const IterVar& v : node->axes) { auto axis_str = v->var.get()->name_hint.operator std::string(); - ICHECK_EQ(axis_str.size(), 1); + TVM_FFI_ICHECK_EQ(axis_str.size(), 1); char axis = axis_str[0]; - ICHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z')); + TVM_FFI_ICHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z')); exist_axis[axis] = true; } for (const IterVar& v : node->axes) { char axis = v->var.get()->name_hint.operator std::string()[0]; if (axis >= 'a' && axis <= 'z') { - ICHECK(exist_axis[axis - 'a' + 'A']) + TVM_FFI_ICHECK(exist_axis[axis - 'a' + 'A']) << "Invalid layout " << name << ": missing axis " << std::toupper(axis); } } @@ -163,13 +163,13 @@ Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) if (!defined()) return Layout::Undef(); const std::string& name = operator->()->name; const auto axes = operator->()->axes; - ICHECK(target_pos <= this->ndim()) + TVM_FFI_ICHECK(target_pos <= this->ndim()) << "Invalid split position " << target_pos << " for layout " << name; - ICHECK(axis.IsPrimal()) << "Cannot split a subordinate axis " << axis; - ICHECK(this->Contains(axis)) << "Axis " << axis << " does not exist in " << name; - ICHECK(!this->Contains(axis.ToSubordinate())) + TVM_FFI_ICHECK(axis.IsPrimal()) << "Cannot split a subordinate axis " << axis; + TVM_FFI_ICHECK(this->Contains(axis)) << "Axis " << axis << " does not exist in " << name; + TVM_FFI_ICHECK(!this->Contains(axis.ToSubordinate())) << "Axis " << axis << " has already been split in " << name; - ICHECK(factor > 0) << "Invalid split size " << factor; + TVM_FFI_ICHECK(factor > 0) << "Invalid split size " << factor; ffi::Array new_layout; for (size_t i = 0; i <= this->ndim(); ++i) { if (i == target_pos) { @@ -192,7 +192,7 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const { if (sub == LayoutAxis::Get(itvar)) { has_sub = true; int32_t val = itvar->dom->extent.as()->value; - ICHECK(val); + TVM_FFI_ICHECK(val); factor *= val; } } @@ -310,17 +310,17 @@ inline ffi::Array TransformIndex(const ffi::Array& src_index } ffi::Array BijectiveLayout::ForwardIndex(const ffi::Array& src_index) const { - ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; + TVM_FFI_ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); - ICHECK_EQ(src_index.size(), self->src_layout->axes.size()) + TVM_FFI_ICHECK_EQ(src_index.size(), self->src_layout->axes.size()) << "Input mismatch with layout " << self->src_layout; return TransformIndex(src_index, self->src_layout->axes, self->index_forward_rule); } ffi::Array BijectiveLayout::BackwardIndex(const ffi::Array& dst_index) const { - ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; + TVM_FFI_ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); - ICHECK_EQ(dst_index.size(), self->dst_layout->axes.size()) + TVM_FFI_ICHECK_EQ(dst_index.size(), self->dst_layout->axes.size()) << "Output mismatch with layout " << self->dst_layout; return TransformIndex(dst_index, self->dst_layout->axes, self->index_backward_rule); } @@ -330,7 +330,7 @@ inline ffi::Array TransformShape(const ffi::Array& src_shape const ffi::Array& target_axis, const ffi::Array& transform_rule) { arith::Analyzer ana; - ICHECK_EQ(src_shape.size(), src_axis.size()) + TVM_FFI_ICHECK_EQ(src_shape.size(), src_axis.size()) << "Input shape size " << src_shape.size() << " mismatch with the expected shape size " << src_axis.size(); // bind variables for original axes @@ -346,7 +346,7 @@ inline ffi::Array TransformShape(const ffi::Array& src_shape const auto* orig_shape_const = orig_shape.as(); const auto* orig_axis_extent = orig_axis->dom->extent.as(); if (orig_shape_const) { - ICHECK_EQ(orig_shape_const->value, orig_axis_extent->value) + TVM_FFI_ICHECK_EQ(orig_shape_const->value, orig_axis_extent->value) << "Input shape mismatch at index " << i << ". Expected " << orig_axis->dom->extent << ", get " << orig_shape; } @@ -362,7 +362,7 @@ inline ffi::Array TransformShape(const ffi::Array& src_shape // for major-axis, use the forward/backward_rule directly, // for minor-axis, simply use the extent. ffi::Array result; - ICHECK_EQ(transform_rule.size(), target_axis.size()); + TVM_FFI_ICHECK_EQ(transform_rule.size(), target_axis.size()); for (size_t i = 0; i < transform_rule.size(); ++i) { PrimExpr rule = transform_rule[i]; IterVar axis = target_axis[i]; @@ -396,14 +396,14 @@ inline ffi::Array TransformShape(const ffi::Array& src_shape } ffi::Array BijectiveLayout::ForwardShape(const ffi::Array& shape) const { - ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; + TVM_FFI_ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes, self->shape_forward_rule); } ffi::Array BijectiveLayout::BackwardShape(const ffi::Array& shape) const { - ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; + TVM_FFI_ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes, self->shape_backward_rule); @@ -417,8 +417,8 @@ BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) { // To be consistent with previous behavior, a nullptr layout is created // when argument is invalid. if (GetStoreRule(&n->index_forward_rule, &n->shape_forward_rule, n->src_layout, n->dst_layout)) { - ICHECK(GetStoreRule(&n->index_backward_rule, &n->shape_backward_rule, n->dst_layout, - n->src_layout)); + TVM_FFI_ICHECK(GetStoreRule(&n->index_backward_rule, &n->shape_backward_rule, n->dst_layout, + n->src_layout)); data_ = std::move(n); } } diff --git a/src/s_tir/meta_schedule/arg_info.cc b/src/s_tir/meta_schedule/arg_info.cc index e71177b8af42..bfc9c1ee527d 100644 --- a/src/s_tir/meta_schedule/arg_info.cc +++ b/src/s_tir/meta_schedule/arg_info.cc @@ -56,12 +56,12 @@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { } // Priority 3: The only PrimFunc in the IRModule if (num_prim_func == 0) { - LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: " << mod; + TVM_FFI_THROW(ValueError) << "Cannot find any PrimFunc in the given IRModule: " << mod; } if (num_prim_func > 1) { - LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them are " - "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`" - << mod; + TVM_FFI_THROW(ValueError) << "Multiple PrimFuncs exist in the IRModule, but none of them are " + "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`" + << mod; } return ffi::GetRef(last_func); } @@ -74,17 +74,17 @@ ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) { ffi::Optional tag{std::nullopt}; try { const ffi::ArrayObj* json_array = json_obj.as(); - CHECK(json_array && json_array->size() >= 1); + TVM_FFI_ICHECK(json_array && json_array->size() >= 1); tag = json_array->at(0).cast(); } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error - LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj - << "\nThe error is: " << e.what(); + TVM_FFI_THROW(ValueError) << "Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); } // Step 2. Dispatch the tag to corresponding subclass of ArgInfo if (tag == "TENSOR") { return TensorInfo::FromJSON(json_obj); } - LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj; + TVM_FFI_THROW(ValueError) << "Unable to parse the JSON object: " << json_obj; throw; } @@ -98,7 +98,7 @@ ffi::Array ArgInfo::FromPrimFunc(const tir::PrimFunc& func) { result.push_back(TensorInfo(/*dtype=*/buffer->dtype, /*shape=*/AsVector(buffer->shape))); } else { - LOG(FATAL) << "ValueError: Unsupported argument type: " << arg; + TVM_FFI_THROW(ValueError) << "Unsupported argument type: " << arg; } } return result; @@ -133,7 +133,7 @@ TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) { ffi::Array shape; try { const ffi::ArrayObj* json_array = json_obj.as(); - CHECK(json_array && json_array->size() == 3); + TVM_FFI_ICHECK(json_array && json_array->size() == 3); // Load json[1] => dtype { ffi::String dtype_str = json_array->at(1).cast(); @@ -142,8 +142,8 @@ TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) { // Load json[2] => shape shape = AsIntArray(json_array->at(2).cast()); } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error - LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj - << "\nThe error is: " << e.what(); + TVM_FFI_THROW(ValueError) << "Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); } std::vector s; std::transform(shape.begin(), shape.end(), std::back_inserter(s), @@ -156,7 +156,7 @@ TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) { TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { const auto* self = n.as(); - ICHECK(self); + TVM_FFI_ICHECK(self); p->stream << "TensorInfo(\"" << self->dtype << "\", " << self->shape << ")"; }); diff --git a/src/s_tir/meta_schedule/cost_model/cost_model.cc b/src/s_tir/meta_schedule/cost_model/cost_model.cc index 20ac8bfefbf1..34d0cf5d77d2 100644 --- a/src/s_tir/meta_schedule/cost_model/cost_model.cc +++ b/src/s_tir/meta_schedule/cost_model/cost_model.cc @@ -25,25 +25,25 @@ namespace s_tir { namespace meta_schedule { void PyCostModelNode::Load(const ffi::String& path) { - ICHECK(f_load != nullptr) << "PyCostModel's Load method not implemented!"; + TVM_FFI_ICHECK(f_load != nullptr) << "PyCostModel's Load method not implemented!"; f_load(path); } void PyCostModelNode::Save(const ffi::String& path) { - ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!"; + TVM_FFI_ICHECK(f_save != nullptr) << "PyCostModel's Save method not implemented!"; f_save(path); } void PyCostModelNode::Update(const TuneContext& context, const ffi::Array& candidates, const ffi::Array& results) { - ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!"; + TVM_FFI_ICHECK(f_update != nullptr) << "PyCostModel's Update method not implemented!"; f_update(context, candidates, results); } std::vector PyCostModelNode::Predict(const TuneContext& context, const ffi::Array& candidates) { - ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!"; + TVM_FFI_ICHECK(f_predict != nullptr) << "PyCostModel's Predict method not implemented!"; std::vector result(candidates.size(), 0.0); f_predict(context, candidates, result.data()); return result; @@ -66,9 +66,9 @@ CostModel CostModel::PyCostModel(PyCostModelNode::FLoad f_load, // TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { const auto* self = n.as(); - ICHECK(self); + TVM_FFI_ICHECK(self); PyCostModelNode::FAsString f_as_string = (*self).f_as_string; - ICHECK(f_as_string != nullptr) << "PyCostModel's AsString method not implemented!"; + TVM_FFI_ICHECK(f_as_string != nullptr) << "PyCostModel's AsString method not implemented!"; p->stream << f_as_string(); }); diff --git a/src/s_tir/meta_schedule/database/database.cc b/src/s_tir/meta_schedule/database/database.cc index 87558e0611f4..c3cd0aca0e2b 100644 --- a/src/s_tir/meta_schedule/database/database.cc +++ b/src/s_tir/meta_schedule/database/database.cc @@ -55,7 +55,7 @@ Workload Workload::FromJSON(const ObjectRef& json_obj) { THashCode shash = 0; try { const ffi::ArrayObj* json_array = json_obj.as(); - CHECK(json_array && json_array->size() == 2); + TVM_FFI_ICHECK(json_array && json_array->size() == 2); // Load json[0] => shash ffi::String str_shash = json_array->at(0).cast(); // Load json[1] => mod @@ -66,8 +66,8 @@ Workload Workload::FromJSON(const ObjectRef& json_obj) { std::stringstream(str_shash) >> shash; } } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error - LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj - << "\nThe error is: " << e.what(); + TVM_FFI_THROW(ValueError) << "Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); } return Workload(mod, shash); } @@ -140,7 +140,7 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w ffi::Optional> args_info; try { const ffi::ArrayObj* json_array = json_obj.as(); - CHECK(json_array && json_array->size() == 4); + TVM_FFI_ICHECK(json_array && json_array->size() == 4); // Load json[1] => run_secs if (json_array->at(1) != nullptr) { run_secs = AsFloatArray(json_array->at(1).cast()); @@ -169,8 +169,8 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w trace = sch->trace().value(); } } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error - LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj - << "\nThe error is: " << e.what(); + TVM_FFI_THROW(ValueError) << "Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); } return TuningRecord(trace, workload, run_secs, target, args_info); } @@ -191,7 +191,7 @@ ffi::Optional DatabaseNode::QueryTuningRecord(const IRModule& mod, if (records.empty()) { return std::nullopt; } - ICHECK_EQ(records.size(), 1); + TVM_FFI_ICHECK_EQ(records.size(), 1); return records[0]; } diff --git a/src/s_tir/meta_schedule/database/database_utils.cc b/src/s_tir/meta_schedule/database/database_utils.cc index 0fa92f51c2a7..120c380aa2e3 100644 --- a/src/s_tir/meta_schedule/database/database_utils.cc +++ b/src/s_tir/meta_schedule/database/database_utils.cc @@ -64,8 +64,8 @@ void JSONDumps(Any json_obj, std::ostringstream& os) { if (auto key = kv.first.try_cast()) { key_values.emplace_back(key.value(), kv.second); } else { - LOG(FATAL) << "TypeError: Only string keys are supported in JSON dumps, but got: " - << kv.first.GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Only string keys are supported in JSON dumps, but got: " + << kv.first.GetTypeKey(); } } std::sort(key_values.begin(), key_values.end(), @@ -84,7 +84,7 @@ void JSONDumps(Any json_obj, std::ostringstream& os) { } else if (json_obj.as()) { JSONDumps(ffi::String(SaveJSON(json_obj)), os); } else { - LOG(FATAL) << "TypeError: Unsupported type in JSON object: " << json_obj.GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Unsupported type in JSON object: " << json_obj.GetTypeKey(); } } @@ -135,7 +135,7 @@ class JSONTokenizer { Token token; if (NextString(&token)) return token; if (NextNumber(&token)) return token; - LOG(FATAL) << "ValueError: Cannot tokenize: " << std::string(cur_, end_); + TVM_FFI_THROW(ValueError) << "Cannot tokenize: " << std::string(cur_, end_); throw; } @@ -205,7 +205,7 @@ class JSONTokenizer { } ++cur_; if (cur_ == end_) { - LOG(FATAL) << "ValueError: Unexpected end of string: \\"; + TVM_FFI_THROW(ValueError) << "Unexpected end of string: \\"; throw; } switch (*cur_) { @@ -234,12 +234,12 @@ class JSONTokenizer { str.push_back('\t'); break; default: - LOG(FATAL) << "ValueError: Unsupported escape sequence: \\" << *cur_ - << ". record:" << std::string(cur_, end_); + TVM_FFI_THROW(ValueError) << "Unsupported escape sequence: \\" << *cur_ + << ". record:" << std::string(cur_, end_); } } if (cur_ == end_) { - LOG(FATAL) << "ValueError: Unexpected end of string"; + TVM_FFI_THROW(ValueError) << "Unexpected end of string"; } ++cur_; *token = Token{TokenType::kString, ffi::String(str)}; @@ -302,15 +302,15 @@ class JSONParser { case TokenType::kFloat: return token.value; case TokenType::kRightSquare: - LOG(FATAL) << "ValueError: Unexpected token: ]"; + TVM_FFI_THROW(ValueError) << "Unexpected token: ]"; case TokenType::kRightCurly: - LOG(FATAL) << "ValueError: Unexpected token: }"; + TVM_FFI_THROW(ValueError) << "Unexpected token: }"; case TokenType::kComma: - LOG(FATAL) << "ValueError: Unexpected token: ,"; + TVM_FFI_THROW(ValueError) << "Unexpected token: ,"; case TokenType::kColon: - LOG(FATAL) << "ValueError: Unexpected token: :"; + TVM_FFI_THROW(ValueError) << "Unexpected token: :"; case TokenType::kEOF: - LOG(FATAL) << "ValueError: Unexpected EOF"; + TVM_FFI_THROW(ValueError) << "Unexpected EOF"; default: throw; } @@ -342,7 +342,7 @@ class JSONParser { results.push_back(ParseObject(std::move(token))); continue; } else { - LOG(FATAL) << "ValueError: Unexpected token before: " << tokenizer_.cur_; + TVM_FFI_THROW(ValueError) << "Unexpected token before: " << tokenizer_.cur_; } } return results; @@ -372,15 +372,16 @@ class JSONParser { } // Case 3 Any key = ParseObject(std::move(token)); - ICHECK(key.as()) << "ValueError: key must be a string, but gets: " << key; + TVM_FFI_CHECK(key.as(), ValueError) + << "key must be a string, but gets: " << key; token = tokenizer_.Next(); - CHECK(token.type == TokenType::kColon) - << "ValueError: Unexpected token before: " << tokenizer_.cur_; + TVM_FFI_CHECK(token.type == TokenType::kColon, ValueError) + << "Unexpected token before: " << tokenizer_.cur_; Any value = ParseObject(tokenizer_.Next()); results.Set(Downcast(key), value); continue; } else { - LOG(FATAL) << "ValueError: Unexpected token before: " << tokenizer_.cur_; + TVM_FFI_THROW(ValueError) << "Unexpected token before: " << tokenizer_.cur_; } } return results; diff --git a/src/s_tir/meta_schedule/database/json_database.cc b/src/s_tir/meta_schedule/database/json_database.cc index 608aef93a108..f1bac43da206 100644 --- a/src/s_tir/meta_schedule/database/json_database.cc +++ b/src/s_tir/meta_schedule/database/json_database.cc @@ -51,9 +51,9 @@ std::vector JSONFileReadLines(const ffi::String& path, int num_threads, boo }); return json_objs; } - CHECK(allow_missing) << "ValueError: File doesn't exist: " << path; + TVM_FFI_CHECK(allow_missing, ValueError) << "File doesn't exist: " << path; std::ofstream os(path); - CHECK(os.good()) << "ValueError: Cannot create new file: " << path; + TVM_FFI_CHECK(os.good(), ValueError) << "Cannot create new file: " << path; return {}; } @@ -64,7 +64,7 @@ std::vector JSONFileReadLines(const ffi::String& path, int num_threads, boo */ void JSONFileAppendLine(const ffi::String& path, const std::string& line) { std::ofstream os(path, std::ofstream::app); - CHECK(os.good()) << "ValueError: Cannot open the file to write: " << path; + TVM_FFI_CHECK(os.good(), ValueError) << "Cannot open the file to write: " << path; os << line << std::endl; } @@ -122,7 +122,7 @@ class JSONDatabaseNode : public DatabaseNode { } ffi::Array GetTopK(const Workload& workload, int top_k) { - CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative"; + TVM_FFI_CHECK_GE(top_k, 0, ValueError) << "top_k must be non-negative"; if (top_k == 0) { return {}; } @@ -192,18 +192,19 @@ Database Database::JSONDatabase(ffi::String path_workload, ffi::String path_tuni Workload workload{ffi::UnsafeInit()}; try { const ffi::ArrayObj* arr = json_obj.as(); - ICHECK_EQ(arr->size(), 2); + TVM_FFI_ICHECK_EQ(arr->size(), 2); int64_t workload_index = arr->at(0).cast()->value; - ICHECK(workload_index >= 0 && static_cast(workload_index) < workloads.size()); + TVM_FFI_ICHECK(workload_index >= 0 && + static_cast(workload_index) < workloads.size()); workload = workloads[workload_index]; records[task_id] = TuningRecord::FromJSON(arr->at(1).cast(), workload); } catch (std::runtime_error& e) { - LOG(FATAL) << "ValueError: Unable to parse TuningRecord, on line " << (task_id + 1) - << " of file " << path_tuning_record << ". The workload is:\n" - << (workload.defined() ? workload->mod->Script() : "(null)") - << "\nThe JSONObject of TuningRecord is:\n" - << json_obj << "\nThe error message is:\n" - << e.what(); + TVM_FFI_THROW(ValueError) << "Unable to parse TuningRecord, on line " << (task_id + 1) + << " of file " << path_tuning_record << ". The workload is:\n" + << (workload.defined() ? workload->mod->Script() : "(null)") + << "\nThe JSONObject of TuningRecord is:\n" + << json_obj << "\nThe error message is:\n" + << e.what(); } }); for (const TuningRecord& record : records) { diff --git a/src/s_tir/meta_schedule/database/memory_database.cc b/src/s_tir/meta_schedule/database/memory_database.cc index 669f48079df3..d7cdb8473cf2 100644 --- a/src/s_tir/meta_schedule/database/memory_database.cc +++ b/src/s_tir/meta_schedule/database/memory_database.cc @@ -65,7 +65,7 @@ class MemoryDatabaseNode : public DatabaseNode { void CommitTuningRecord(const TuningRecord& record) final { records.push_back(record); } ffi::Array GetTopK(const Workload& workload, int top_k) final { - CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative"; + TVM_FFI_CHECK_GE(top_k, 0, ValueError) << "top_k must be non-negative"; if (top_k == 0) { return {}; } diff --git a/src/s_tir/meta_schedule/database/ordered_union_database.cc b/src/s_tir/meta_schedule/database/ordered_union_database.cc index fe687435758b..5b08b24d391e 100644 --- a/src/s_tir/meta_schedule/database/ordered_union_database.cc +++ b/src/s_tir/meta_schedule/database/ordered_union_database.cc @@ -48,32 +48,32 @@ class OrderedUnionDatabaseNode : public DatabaseNode { } bool HasWorkload(const IRModule& mod) final { - LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.HasWorkload"; + TVM_FFI_THROW(NotImplementedError) << "OrderedUnionDatabase.HasWorkload"; throw; } Workload CommitWorkload(const IRModule& mod) final { - LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.CommitWorkload"; + TVM_FFI_THROW(NotImplementedError) << "OrderedUnionDatabase.CommitWorkload"; throw; } void CommitTuningRecord(const TuningRecord& record) final { - LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.CommitTuningRecord"; + TVM_FFI_THROW(NotImplementedError) << "OrderedUnionDatabase.CommitTuningRecord"; throw; } ffi::Array GetTopK(const Workload& workload, int top_k) final { - LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.GetTopK"; + TVM_FFI_THROW(NotImplementedError) << "OrderedUnionDatabase.GetTopK"; throw; } ffi::Array GetAllTuningRecords() final { - LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.GetAllTuningRecords"; + TVM_FFI_THROW(NotImplementedError) << "OrderedUnionDatabase.GetAllTuningRecords"; throw; } int64_t Size() final { - LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.size"; + TVM_FFI_THROW(NotImplementedError) << "OrderedUnionDatabase.size"; throw; } }; diff --git a/src/s_tir/meta_schedule/database/schedule_fn_database.cc b/src/s_tir/meta_schedule/database/schedule_fn_database.cc index 261312679565..e7259a2f7274 100644 --- a/src/s_tir/meta_schedule/database/schedule_fn_database.cc +++ b/src/s_tir/meta_schedule/database/schedule_fn_database.cc @@ -66,32 +66,32 @@ class ScheduleFnDatabaseNode : public DatabaseNode { } bool HasWorkload(const IRModule& mod) final { - LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.HasWorkload"; + TVM_FFI_THROW(NotImplementedError) << "ScheduleFnDatabase.HasWorkload"; throw; } Workload CommitWorkload(const IRModule& mod) final { - LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.CommitWorkload"; + TVM_FFI_THROW(NotImplementedError) << "ScheduleFnDatabase.CommitWorkload"; throw; } void CommitTuningRecord(const TuningRecord& record) final { - LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.CommitTuningRecord"; + TVM_FFI_THROW(NotImplementedError) << "ScheduleFnDatabase.CommitTuningRecord"; throw; } ffi::Array GetTopK(const Workload& workload, int top_k) final { - LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.GetTopK"; + TVM_FFI_THROW(NotImplementedError) << "ScheduleFnDatabase.GetTopK"; throw; } ffi::Array GetAllTuningRecords() final { - LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.GetAllTuningRecords"; + TVM_FFI_THROW(NotImplementedError) << "ScheduleFnDatabase.GetAllTuningRecords"; throw; } int64_t Size() final { - LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.size"; + TVM_FFI_THROW(NotImplementedError) << "ScheduleFnDatabase.size"; throw; } }; diff --git a/src/s_tir/meta_schedule/database/union_database.cc b/src/s_tir/meta_schedule/database/union_database.cc index 4d0a117b413a..e2c8c4234650 100644 --- a/src/s_tir/meta_schedule/database/union_database.cc +++ b/src/s_tir/meta_schedule/database/union_database.cc @@ -50,32 +50,32 @@ class UnionDatabaseNode : public DatabaseNode { } bool HasWorkload(const IRModule& mod) final { - LOG(FATAL) << "NotImplementedError: UnionDatabase.HasWorkload"; + TVM_FFI_THROW(NotImplementedError) << "UnionDatabase.HasWorkload"; throw; } Workload CommitWorkload(const IRModule& mod) final { - LOG(FATAL) << "NotImplementedError: UnionDatabase.CommitWorkload"; + TVM_FFI_THROW(NotImplementedError) << "UnionDatabase.CommitWorkload"; throw; } void CommitTuningRecord(const TuningRecord& record) final { - LOG(FATAL) << "NotImplementedError: UnionDatabase.CommitTuningRecord"; + TVM_FFI_THROW(NotImplementedError) << "UnionDatabase.CommitTuningRecord"; throw; } ffi::Array GetTopK(const Workload& workload, int top_k) final { - LOG(FATAL) << "NotImplementedError: UnionDatabase.GetTopK"; + TVM_FFI_THROW(NotImplementedError) << "UnionDatabase.GetTopK"; throw; } ffi::Array GetAllTuningRecords() final { - LOG(FATAL) << "NotImplementedError: UnionDatabase.GetAllTuningRecords"; + TVM_FFI_THROW(NotImplementedError) << "UnionDatabase.GetAllTuningRecords"; throw; } int64_t Size() final { - LOG(FATAL) << "NotImplementedError: UnionDatabase.size"; + TVM_FFI_THROW(NotImplementedError) << "UnionDatabase.size"; throw; } }; diff --git a/src/s_tir/meta_schedule/feature_extractor/feature_extractor.cc b/src/s_tir/meta_schedule/feature_extractor/feature_extractor.cc index a0fb2f710490..e037231943cb 100644 --- a/src/s_tir/meta_schedule/feature_extractor/feature_extractor.cc +++ b/src/s_tir/meta_schedule/feature_extractor/feature_extractor.cc @@ -26,7 +26,8 @@ namespace meta_schedule { ffi::Array PyFeatureExtractorNode::ExtractFrom( const TuneContext& context, const ffi::Array& candidates) { - ICHECK(f_extract_from != nullptr) << "PyFeatureExtractor's ExtractFrom method not implemented!"; + TVM_FFI_ICHECK(f_extract_from != nullptr) + << "PyFeatureExtractor's ExtractFrom method not implemented!"; return f_extract_from(context, candidates); } @@ -42,9 +43,10 @@ FeatureExtractor FeatureExtractor::PyFeatureExtractor( TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { const auto* self = n.as(); - ICHECK(self); + TVM_FFI_ICHECK(self); PyFeatureExtractorNode::FAsString f_as_string = (*self).f_as_string; - ICHECK(f_as_string != nullptr) << "PyFeatureExtractor's AsString method not implemented!"; + TVM_FFI_ICHECK(f_as_string != nullptr) + << "PyFeatureExtractor's AsString method not implemented!"; p->stream << f_as_string(); }); diff --git a/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc b/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc index 908c4155c1b2..038d65217f50 100644 --- a/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc @@ -205,7 +205,7 @@ int64_t GetVarStride(const std::vector& multi_indices, const IntVec& // Calculate the min stride possible int64_t result = kNotFound; for (const MultiIndex& multi_index : multi_indices) { - ICHECK_EQ(multi_index.size(), buffer_stride.size()); + TVM_FFI_ICHECK_EQ(multi_index.size(), buffer_stride.size()); // Find the rightest dimension that contains the given variable for (int i = ndim - 1; i >= 0; --i) { int64_t coef = CoefficientExtractor::Extract(multi_index[i], var); @@ -228,7 +228,7 @@ int64_t GetVarStride(const std::vector& multi_indices, const IntVec& */ runtime::Tensor AsTensor(const std::vector>& src, int second_dim_size = -1) { int n = src.size(); - ICHECK(!src.empty() || second_dim_size != -1); + TVM_FFI_ICHECK(!src.empty() || second_dim_size != -1); int m = src.empty() ? second_dim_size : src[0].size(); runtime::Tensor tgt = runtime::Tensor::Empty( /*shape=*/{n, m}, @@ -382,7 +382,7 @@ struct LoopNest { } else if (support::StartsWith(thread_tag, "vthread")) { ref_loops = &vthread; } else { - LOG(FATAL) << "ValueError: Unable to recognize thread tag: " << thread_tag; + TVM_FFI_THROW(ValueError) << "Unable to recognize thread tag: " << thread_tag; } } if (ref_loops != nullptr) { @@ -875,7 +875,7 @@ void Feature::SubFeature::SetStride(const LoopNest& loop_nest, arith::Analyzer* { int64_t& num_continuous_bytes = this->num_continuous_bytes = 1; const IntVec& access_shape = this->access_shape; - ICHECK_EQ(access_shape.size(), buffer_shape.size()); + TVM_FFI_ICHECK_EQ(access_shape.size(), buffer_shape.size()); for (int i = ndim - 1; i >= 0; --i) { if (access_shape[i] == buffer_shape[i]) { num_continuous_bytes = buffer_shape[i] * buffer->dtype.bytes(); @@ -1077,7 +1077,7 @@ struct Feature { const group1::Feature::ArithOps& arith_ops) : arith_intensity_curve(n_samples, 0.0) { const std::vector& loops = loop_nest.loops; - ICHECK_EQ(loops.size(), for_touched_bytes.size()); + TVM_FFI_ICHECK_EQ(loops.size(), for_touched_bytes.size()); int n_loops = loops.size(); // Calculate `memory_bytes` std::vector memory_bytes; @@ -1115,7 +1115,7 @@ struct Feature { break; } } - CHECK_LT(p, n_loops); + TVM_FFI_ICHECK_LT(p, n_loops); if (p == 0) { result = slog(compute_ops[p] / memory_bytes[p]); } else { @@ -1287,10 +1287,10 @@ class PerStoreFeatureCollector : private StmtVisitor { for (auto& it : collector.buffer_features_) { Feature& feature = it.second; if (feature.buffer != nullptr) { - ICHECK(feature.group1); - ICHECK(feature.group2); - ICHECK(feature.group3); - ICHECK(feature.group5); + TVM_FFI_ICHECK(feature.group1); + TVM_FFI_ICHECK(feature.group2); + TVM_FFI_ICHECK(feature.group3); + TVM_FFI_ICHECK(feature.group5); if (feature.group4 == nullptr) { feature.group4 = std::make_unique(); } diff --git a/src/s_tir/meta_schedule/measure_callback/add_to_database.cc b/src/s_tir/meta_schedule/measure_callback/add_to_database.cc index 39291114ad0b..a71fe3ca2ae4 100644 --- a/src/s_tir/meta_schedule/measure_callback/add_to_database.cc +++ b/src/s_tir/meta_schedule/measure_callback/add_to_database.cc @@ -38,7 +38,7 @@ class AddToDatabaseNode : public MeasureCallbackNode { Database database = task_scheduler->database_.value(); Workload workload = database->CommitWorkload(task->mod.value()); Target target = task->target.value(); - ICHECK_EQ(runner_results.size(), measure_candidates.size()); + TVM_FFI_ICHECK_EQ(runner_results.size(), measure_candidates.size()); int n = runner_results.size(); for (int i = 0; i < n; ++i) { RunnerResult result = runner_results[i]; diff --git a/src/s_tir/meta_schedule/measure_callback/measure_callback.cc b/src/s_tir/meta_schedule/measure_callback/measure_callback.cc index 4658c477e932..9f2a3056c258 100644 --- a/src/s_tir/meta_schedule/measure_callback/measure_callback.cc +++ b/src/s_tir/meta_schedule/measure_callback/measure_callback.cc @@ -29,7 +29,7 @@ void PyMeasureCallbackNode::Apply(const TaskScheduler& task_scheduler, const ffi::Array& measure_candidates, // const ffi::Array& builds, // const ffi::Array& results) { - ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!"; + TVM_FFI_ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!"; auto _ = Profiler::TimedScope("MeasureCallback/" + this->f_as_string()); return f_apply(task_scheduler, task_id, measure_candidates, builds, results); } @@ -53,9 +53,10 @@ ffi::Array MeasureCallback::Default() { TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { const auto* self = n.as(); - ICHECK(self); + TVM_FFI_ICHECK(self); PyMeasureCallbackNode::FAsString f_as_string = (*self).f_as_string; - ICHECK(f_as_string != nullptr) << "PyMeasureCallback's AsString method not implemented!"; + TVM_FFI_ICHECK(f_as_string != nullptr) + << "PyMeasureCallback's AsString method not implemented!"; p->stream << f_as_string(); }); diff --git a/src/s_tir/meta_schedule/measure_callback/update_cost_model.cc b/src/s_tir/meta_schedule/measure_callback/update_cost_model.cc index 91499bb30bf2..08a02352c63d 100644 --- a/src/s_tir/meta_schedule/measure_callback/update_cost_model.cc +++ b/src/s_tir/meta_schedule/measure_callback/update_cost_model.cc @@ -36,9 +36,10 @@ class UpdateCostModelNode : public MeasureCallbackNode { return; } CostModel cost_model = task_scheduler->cost_model_.value(); - ICHECK(task->measure_candidates.defined()) << "Task's measure candidates must be present!"; - ICHECK_EQ(measure_candidates.size(), builder_results.size()); - ICHECK_EQ(runner_results.size(), builder_results.size()); + TVM_FFI_ICHECK(task->measure_candidates.defined()) + << "Task's measure candidates must be present!"; + TVM_FFI_ICHECK_EQ(measure_candidates.size(), builder_results.size()); + TVM_FFI_ICHECK_EQ(runner_results.size(), builder_results.size()); int n = builder_results.size(); ffi::Array pruned_candidate; ffi::Array pruned_runner_result; diff --git a/src/s_tir/meta_schedule/module_equality.cc b/src/s_tir/meta_schedule/module_equality.cc index 033db234e1c3..6973ba809627 100644 --- a/src/s_tir/meta_schedule/module_equality.cc +++ b/src/s_tir/meta_schedule/module_equality.cc @@ -85,7 +85,7 @@ std::unique_ptr ModuleEquality::Create(const std::string& mod_eq } else if (mod_eq_name == "anchor-block") { return std::make_unique(); } - LOG(FATAL) << "Unknown module equality " << mod_eq_name; + TVM_FFI_THROW(InternalError) << "Unknown module equality " << mod_eq_name; } } // namespace meta_schedule diff --git a/src/s_tir/meta_schedule/mutator/mutate_compute_location.cc b/src/s_tir/meta_schedule/mutator/mutate_compute_location.cc index c5fe144d84a6..67fd867e48b9 100644 --- a/src/s_tir/meta_schedule/mutator/mutate_compute_location.cc +++ b/src/s_tir/meta_schedule/mutator/mutate_compute_location.cc @@ -92,7 +92,7 @@ std::vector MutateComputeLocationNode::Fin const Any& decision) -> Any { if (inst->kind.same_as(inst_sample_compute_location)) { // Step 1. Extract the instruction input and the old decision. - ICHECK_EQ(inputs.size(), 1); + TVM_FFI_ICHECK_EQ(inputs.size(), 1); tir::StmtSRef block_sref = sch->GetSRef(Downcast(inputs[0])); int old_decision = Downcast(decision)->value; @@ -104,7 +104,7 @@ std::vector MutateComputeLocationNode::Fin location_srefs.erase(location_srefs.begin() + (it - location_indices.begin())); location_indices.erase(it); } - ICHECK_EQ(location_srefs.size(), location_indices.size()); + TVM_FFI_ICHECK_EQ(location_srefs.size(), location_indices.size()); // Step 4. Add a new candidate if there are at least one remaining compute-at position. if (!location_srefs.empty()) { candidates.emplace_back(inst, std::move(location_indices)); diff --git a/src/s_tir/meta_schedule/mutator/mutate_parallel.cc b/src/s_tir/meta_schedule/mutator/mutate_parallel.cc index e3377fc9c21a..2bd43e0a650d 100644 --- a/src/s_tir/meta_schedule/mutator/mutate_parallel.cc +++ b/src/s_tir/meta_schedule/mutator/mutate_parallel.cc @@ -38,7 +38,7 @@ bool IsAnnotateWithParallel(const Instruction& inst) { if (!inst->kind.same_as(inst_annotate)) { return false; } - ICHECK_EQ(inst->attrs.size(), 1); + TVM_FFI_ICHECK_EQ(inst->attrs.size(), 1); ffi::String ann_key = Downcast(inst->attrs[0]); return ann_key == tir::attr::meta_schedule_parallel; } @@ -50,7 +50,7 @@ bool IsAnnotateWithParallel(const Instruction& inst) { * \return The replaced instruction */ Instruction ReplaceAnnValue(Instruction inst, int64_t ann_val) { - ICHECK_EQ(inst->inputs.size(), 2); + TVM_FFI_ICHECK_EQ(inst->inputs.size(), 2); return Instruction(/*kind=*/inst->kind, // /*inputs=*/{inst->inputs[0], Integer(ann_val)}, // /*attrs=*/inst->attrs, @@ -67,7 +67,7 @@ const SBlockRVNode* GetInstGetSBlockOutput(const Instruction& inst) { if (!inst->kind.same_as(inst_get_sblock)) { return nullptr; } - ICHECK_EQ(inst->outputs.size(), 1); + TVM_FFI_ICHECK_EQ(inst->outputs.size(), 1); const SBlockRVNode* block = TVM_TYPE_AS(inst->outputs[0], SBlockRVNode); return block; } @@ -85,7 +85,7 @@ std::vector> AnalyzeParallel(const ScheduleState& self, const ffi::String& func_name, int64_t limit) { ffi::Array block_srefs = GetSBlocks(self, block_name, self->mod->GetGlobalVar(func_name)); - ICHECK_EQ(block_srefs.size(), 1); + TVM_FFI_ICHECK_EQ(block_srefs.size(), 1); const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_srefs[0]); ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(ffi::GetRef(block)); std::vector> results; @@ -239,10 +239,10 @@ bool FindParallelDecision(const Trace& trace, TRandState* rand_state, return false; } const InstructionNode* ann_inst = ann_insts[s_tir::SampleInt(rand_state, 0, n_ann_insts)]; - ICHECK_EQ(ann_inst->inputs.size(), 2); + TVM_FFI_ICHECK_EQ(ann_inst->inputs.size(), 2); const InstructionNode* get_sblock_inst = get_sblock_insts.at(Downcast(ann_inst->inputs[0]).get()); - ICHECK_EQ(get_sblock_inst->attrs.size(), 2); + TVM_FFI_ICHECK_EQ(get_sblock_inst->attrs.size(), 2); candidate->inst = ffi::GetRef(ann_inst); candidate->parallel_extent = Downcast(ann_inst->inputs[1])->value; candidate->block_name = Downcast(get_sblock_inst->attrs[0]); diff --git a/src/s_tir/meta_schedule/mutator/mutate_thread_binding.cc b/src/s_tir/meta_schedule/mutator/mutate_thread_binding.cc index 15ff27bd457e..d1310907b016 100644 --- a/src/s_tir/meta_schedule/mutator/mutate_thread_binding.cc +++ b/src/s_tir/meta_schedule/mutator/mutate_thread_binding.cc @@ -100,7 +100,7 @@ std::vector MutateThreadBindingNode::FindCan } // Only consider cases with 2 factors and the first one is None if (inst->inputs.size() != 3 || inst->inputs[1] != nullptr) return false; - ICHECK(inst->inputs[2] != nullptr); + TVM_FFI_ICHECK(inst->inputs[2] != nullptr); return sample_insts.find(Downcast(inst->inputs[2]).get()) != sample_insts.end(); }; @@ -109,8 +109,8 @@ std::vector MutateThreadBindingNode::FindCan if (!inst->kind.same_as(inst_bind)) { return false; } - ICHECK_EQ(inst->inputs.size(), 1); - ICHECK_EQ(inst->attrs.size(), 1); + TVM_FFI_ICHECK_EQ(inst->inputs.size(), 1); + TVM_FFI_ICHECK_EQ(inst->attrs.size(), 1); if (Downcast(inst->attrs[0]) != "threadIdx.x") return false; return sampled_split_insts.find(Downcast(inst->inputs[0]).get()) != @@ -119,11 +119,11 @@ std::vector MutateThreadBindingNode::FindCan for (const Instruction& inst : trace->insts) { if (inst->kind.same_as(inst_sample_categorical)) { - ICHECK_EQ(inst->outputs.size(), 1); + TVM_FFI_ICHECK_EQ(inst->outputs.size(), 1); const PrimExprNode* var_rv = TVM_TYPE_AS(inst->outputs[0], PrimExprNode); sample_insts[var_rv] = inst.get(); } else if (is_split_by_sample(inst)) { - CHECK_EQ(inst->outputs.size(), 2); + TVM_FFI_ICHECK_EQ(inst->outputs.size(), 2); // Only consider the inner loop, which can be bound to threadIdx.x const s_tir::LoopRVNode* var_rv = TVM_TYPE_AS(inst->outputs[1], s_tir::LoopRVNode); sampled_split_insts[var_rv] = inst.get(); @@ -135,12 +135,12 @@ std::vector MutateThreadBindingNode::FindCan for (const InstructionNode* bind_inst : bind_insts) { const auto* loop_rv = TVM_TYPE_AS(bind_inst->inputs[0], s_tir::LoopRVNode); auto split_it = sampled_split_insts.find(loop_rv); - ICHECK(split_it != sampled_split_insts.end()); + TVM_FFI_ICHECK(split_it != sampled_split_insts.end()); const InstructionNode* split_inst = split_it->second; const auto* expr_rv = TVM_TYPE_AS(split_inst->inputs[2], PrimExprNode); auto sample_it = sample_insts.find(expr_rv); - ICHECK(sample_it != sample_insts.end()); + TVM_FFI_ICHECK(sample_it != sample_insts.end()); const InstructionNode* sample_inst = sample_it->second; int decision = Downcast(trace->decisions[ffi::GetRef(sample_inst)])->value; diff --git a/src/s_tir/meta_schedule/mutator/mutate_tile_size.cc b/src/s_tir/meta_schedule/mutator/mutate_tile_size.cc index 7e10fe7a3920..d6a43607c0a0 100644 --- a/src/s_tir/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/s_tir/meta_schedule/mutator/mutate_tile_size.cc @@ -117,11 +117,11 @@ void FindSampleVectorize(const Trace& trace, std::vector* inst, // Find annotation with `meta_schedule_cooperative_fetch` for (const Instruction& inst : trace->insts) { if (inst->kind.same_as(inst_annotate)) { - ICHECK_EQ(inst->attrs.size(), 1); - ICHECK_EQ(inst->inputs.size(), 2); + TVM_FFI_ICHECK_EQ(inst->attrs.size(), 1); + TVM_FFI_ICHECK_EQ(inst->inputs.size(), 2); if (Downcast(inst->attrs[0]) == tir::attr::meta_schedule_cooperative_fetch) { const auto* ann_val = inst->inputs[1].as(); - ICHECK(ann_val); + TVM_FFI_ICHECK(ann_val); annotated.insert(ann_val); } } @@ -130,9 +130,9 @@ void FindSampleVectorize(const Trace& trace, std::vector* inst, for (const auto& kv : trace->decisions) { const Instruction& inst = kv.first; if (inst->kind.same_as(inst_sample_categorical)) { - ICHECK_EQ(inst->outputs.size(), 1); + TVM_FFI_ICHECK_EQ(inst->outputs.size(), 1); if (annotated.count(inst->outputs[0].as())) { - ICHECK_EQ(inst->attrs.size(), 2); + TVM_FFI_ICHECK_EQ(inst->attrs.size(), 2); std::vector probs = support::AsVector(Downcast>(inst->attrs[1])); if (probs.size() == 1) { @@ -237,7 +237,7 @@ ffi::Optional MutateSampleTileSize(const Trace& trace, Instruction inst, ffi::Optional MutateSampleVectorize(const Trace& trace, Instruction inst, int64_t original_decision, TRandState* rand_state) { - ICHECK_EQ(inst->attrs.size(), 2); + TVM_FFI_ICHECK_EQ(inst->attrs.size(), 2); std::vector probs = support::AsVector(Downcast>(inst->attrs[1])); probs.erase(probs.begin() + original_decision); diff --git a/src/s_tir/meta_schedule/mutator/mutate_unroll.cc b/src/s_tir/meta_schedule/mutator/mutate_unroll.cc index 749ad43bcb49..47fed617b655 100644 --- a/src/s_tir/meta_schedule/mutator/mutate_unroll.cc +++ b/src/s_tir/meta_schedule/mutator/mutate_unroll.cc @@ -35,7 +35,7 @@ bool IsAnnotateWithUnroll(const Instruction& inst) { if (!inst->kind.same_as(inst_annotate)) { return false; } - ICHECK_EQ(inst->attrs.size(), 1); + TVM_FFI_ICHECK_EQ(inst->attrs.size(), 1); ffi::String ann_key = Downcast(inst->attrs[0]); return ann_key == tir::attr::meta_schedule_unroll_explicit || ann_key == tir::attr::meta_schedule_unroll_implicit; @@ -102,7 +102,7 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, ann_insts.reserve(trace->insts.size()); for (const Instruction& inst : trace->insts) { if (inst->kind.same_as(inst_sample_categorical)) { - ICHECK_EQ(inst->outputs.size(), 1); + TVM_FFI_ICHECK_EQ(inst->outputs.size(), 1); const PrimExprNode* var_rv = TVM_TYPE_AS(inst->outputs[0], PrimExprNode); sample_insts[var_rv] = inst.get(); } else if (IsAnnotateWithUnroll(inst)) { @@ -114,11 +114,11 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, return false; } const InstructionNode* ann_inst = ann_insts[s_tir::SampleInt(rand_state, 0, n_ann_insts)]; - ICHECK_EQ(ann_inst->inputs.size(), 2); + TVM_FFI_ICHECK_EQ(ann_inst->inputs.size(), 2); const auto* var_rv = TVM_TYPE_AS(ann_inst->inputs[1], PrimExprNode); - ICHECK(sample_insts.count(var_rv)); + TVM_FFI_ICHECK(sample_insts.count(var_rv)); const InstructionNode* sample_inst = sample_insts.at(var_rv); - ICHECK_EQ(sample_inst->attrs.size(), 2); + TVM_FFI_ICHECK_EQ(sample_inst->attrs.size(), 2); candidate->inst = ffi::GetRef(sample_inst); candidate->decision = Downcast(trace->decisions[ffi::GetRef(sample_inst)])->value; diff --git a/src/s_tir/meta_schedule/mutator/mutator.cc b/src/s_tir/meta_schedule/mutator/mutator.cc index ca89cbd3ae63..8821a239b4c9 100644 --- a/src/s_tir/meta_schedule/mutator/mutator.cc +++ b/src/s_tir/meta_schedule/mutator/mutator.cc @@ -25,19 +25,19 @@ namespace s_tir { namespace meta_schedule { void PyMutatorNode::InitializeWithTuneContext(const TuneContext& context) { - ICHECK(f_initialize_with_tune_context != nullptr) + TVM_FFI_ICHECK(f_initialize_with_tune_context != nullptr) << "PyMutator's InitializeWithTuneContext method not implemented!"; f_initialize_with_tune_context(context); } ffi::Optional PyMutatorNode::Apply( const s_tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) { - ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!"; + TVM_FFI_ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!"; return f_apply(trace, *rand_state); } Mutator PyMutatorNode::Clone() const { - ICHECK(f_clone != nullptr) << "PyMutator's Clone method not implemented!"; + TVM_FFI_ICHECK(f_clone != nullptr) << "PyMutator's Clone method not implemented!"; return f_clone(); } @@ -82,9 +82,9 @@ ffi::Map Mutator::DefaultHexagon() { TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { const auto* self = n.as(); - ICHECK(self); + TVM_FFI_ICHECK(self); PyMutatorNode::FAsString f_as_string = (*self).f_as_string; - ICHECK(f_as_string != nullptr) << "PyMutator's AsString method not implemented!"; + TVM_FFI_ICHECK(f_as_string != nullptr) << "PyMutator's AsString method not implemented!"; p->stream << f_as_string(); }); diff --git a/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index 7b2184f83c4a..753c929dd656 100644 --- a/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -126,7 +126,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { // Inherited from PostprocNode void InitializeWithTuneContext(const TuneContext& context) final { /* Null check */ - ICHECK(context->target) << "Context must contain a target"; + TVM_FFI_ICHECK(context->target) << "Context must contain a target"; this->target = context->target.value(); } // Inherited from PostprocNode diff --git a/src/s_tir/meta_schedule/postproc/postproc.cc b/src/s_tir/meta_schedule/postproc/postproc.cc index 85fa7bb29a79..ac8f73f260d7 100644 --- a/src/s_tir/meta_schedule/postproc/postproc.cc +++ b/src/s_tir/meta_schedule/postproc/postproc.cc @@ -25,18 +25,18 @@ namespace s_tir { namespace meta_schedule { void PyPostprocNode::InitializeWithTuneContext(const TuneContext& context) { - ICHECK(f_initialize_with_tune_context != nullptr) + TVM_FFI_ICHECK(f_initialize_with_tune_context != nullptr) << "PyPostproc's InitializeWithTuneContext method not implemented!"; f_initialize_with_tune_context(context); } bool PyPostprocNode::Apply(const s_tir::Schedule& sch) { - ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!"; + TVM_FFI_ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!"; return f_apply(sch); } Postproc PyPostprocNode::Clone() const { - ICHECK(f_clone != nullptr) << "PyPostproc's Clone method not implemented!"; + TVM_FFI_ICHECK(f_clone != nullptr) << "PyPostproc's Clone method not implemented!"; return f_clone(); } @@ -114,9 +114,9 @@ ffi::Array Postproc::DefaultHexagon() { TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { const auto* self = n.as(); - ICHECK(self); + TVM_FFI_ICHECK(self); PyPostprocNode::FAsString f_as_string = (*self).f_as_string; - ICHECK(f_as_string != nullptr) << "PyPostproc's AsString method not implemented!"; + TVM_FFI_ICHECK(f_as_string != nullptr) << "PyPostproc's AsString method not implemented!"; p->stream << f_as_string(); }); diff --git a/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc index 817f2ac94827..151a50b4a0ad 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -37,8 +37,8 @@ ffi::Optional ParseThreadBinding(const Schedule& sch, const Instruction if (!inst->kind.same_as(inst_kind_bind)) { return std::nullopt; } - ICHECK_EQ(inst->inputs.size(), 1); - ICHECK_EQ(inst->attrs.size(), 1); + TVM_FFI_ICHECK_EQ(inst->inputs.size(), 1); + TVM_FFI_ICHECK_EQ(inst->attrs.size(), 1); ffi::String thread_axis = Downcast(inst->attrs[0]); if (thread_axis != axis) { return std::nullopt; @@ -59,8 +59,8 @@ ffi::Optional ParseAnnotate(const Schedule& sch, const Instruction& in if (!inst->kind.same_as(inst_kind_annotate)) { return std::nullopt; } - ICHECK_EQ(inst->inputs.size(), 2); - ICHECK_EQ(inst->attrs.size(), 1); + TVM_FFI_ICHECK_EQ(inst->inputs.size(), 2); + TVM_FFI_ICHECK_EQ(inst->attrs.size(), 1); ffi::String ann_key = Downcast(inst->attrs[0]); if (ann_key != tir::attr::meta_schedule_cooperative_fetch) { return std::nullopt; @@ -80,8 +80,8 @@ bool ParseWarpExecutionAnn(const Schedule& sch, const Instruction& inst) { if (!inst->kind.same_as(inst_kind_annotate)) { return false; } - ICHECK_EQ(inst->inputs.size(), 2); - ICHECK_EQ(inst->attrs.size(), 1); + TVM_FFI_ICHECK_EQ(inst->inputs.size(), 2); + TVM_FFI_ICHECK_EQ(inst->attrs.size(), 1); ffi::String ann_key = Downcast(inst->attrs[0]); return ann_key == tir::attr::warp_execution; } diff --git a/src/s_tir/meta_schedule/postproc/rewrite_layout.cc b/src/s_tir/meta_schedule/postproc/rewrite_layout.cc index 6ad7f65c3c18..78000e54ed21 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_layout.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_layout.cc @@ -54,7 +54,7 @@ class BufferReadPosCollector : public StmtExprVisitor { } void VisitExpr_(const BufferLoadNode* op) final { - CHECK(cur_realize_.defined()) << "BufferLoad occurred outside of any block"; + TVM_FFI_ICHECK(cur_realize_.defined()) << "BufferLoad occurred outside of any block"; const Buffer& buffer = op->buffer; if (buffer_ == buffer.get()) { @@ -74,7 +74,7 @@ class BufferReadPosCollector : public StmtExprVisitor { /*predicate=*/cur_realize_->predicate, // /*analyzer=*/&analyzer_); int buffer_index = GetReadBufferIndex(cur_realize_->block, buffer); - ICHECK(buffer_index != -1); + TVM_FFI_ICHECK(buffer_index != -1); buffer_loc_ = std::make_pair(cur_realize_->block, buffer_index); } } @@ -125,7 +125,7 @@ ffi::Array CollectLayoutFreeBuffers(const PrimFuncNode* func) { ffi::Array layout_free_buffers; for (const Integer& index : layout_free_buffer_index) { - ICHECK(static_cast(index->value) < func->params.size()); + TVM_FFI_ICHECK(static_cast(index->value) < func->params.size()); const Var& param = func->params[index->value]; layout_free_buffers.push_back(func->buffer_map.at(param)); } @@ -215,7 +215,7 @@ bool RewriteLayout(const Schedule& sch) { // for a cache-read buffer that is directly consumed by an anchor op. The last buffer // in cache_read_chain corresponds to that buffer. SBlock cache_read_block = sch->Get(sch->GetSBlock(cache_read_chain.back(), func_name)); - ICHECK_EQ(cache_read_block->writes.size(), 1); + TVM_FFI_ICHECK_EQ(cache_read_block->writes.size(), 1); auto tup_opt = GetSuggestedIndexMap(cache_read_block->writes[0]->buffer, prim_func); if (tup_opt == std::nullopt) continue; diff --git a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index ce2cf12502b3..099f2f449cbf 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -151,7 +151,7 @@ void RemoveParsedAnn(const Schedule& sch, const SBlockRV& block_rv, int CalculateNumRewritableLoops(const ffi::Array& loop_srefs, const std::vector& loop_types) { int rw_loops_num = 0; - ICHECK_EQ(loop_srefs.size(), loop_types.size()); + TVM_FFI_ICHECK_EQ(loop_srefs.size(), loop_types.size()); for (size_t i = 0; i < loop_srefs.size(); ++i) { const StmtSRef& loop_sref = loop_srefs[i]; const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); @@ -364,7 +364,7 @@ void RewriteFuseSplitParallelVectorize(const Schedule& sch, ffi::Array* size_t n_loops = loop_rvs->size(); LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->end()}); ffi::Array split = sch->Split(fused, {std::nullopt, Integer(vec_len)}); - ICHECK_EQ(split.size(), 2); + TVM_FFI_ICHECK_EQ(split.size(), 2); const LoopRV& outer = split[0]; const LoopRV& inner = split[1]; sch->Parallel(outer); @@ -376,7 +376,7 @@ void RewriteFuseSplitParallelVectorize(const Schedule& sch, ffi::Array* } void RewriteParallel(const Schedule& sch, size_t n, ffi::Array* loop_rvs) { - ICHECK_LE(n, loop_rvs->size()); + TVM_FFI_ICHECK_LE(n, loop_rvs->size()); LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->begin() + n}); sch->Parallel(fused); for (size_t i = 0; i < n; ++i) { @@ -386,7 +386,7 @@ void RewriteParallel(const Schedule& sch, size_t n, ffi::Array* loop_rvs void RewriteVectorize(const Schedule& sch, size_t n, ffi::Array* loop_rvs) { size_t n_loops = loop_rvs->size(); - ICHECK_LE(n, n_loops); + TVM_FFI_ICHECK_LE(n, n_loops); LoopRV fused = sch->Fuse({loop_rvs->end() - n, loop_rvs->end()}); sch->Vectorize(fused); for (size_t i = n_loops - n; i < n_loops; ++i) { @@ -445,7 +445,7 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode { } // AutoUnroll if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) { - ICHECK(parsed.unroll_explicit == -1 || parsed.unroll_implicit == -1); + TVM_FFI_ICHECK(parsed.unroll_explicit == -1 || parsed.unroll_implicit == -1); int unroll_explicit = parsed.unroll_explicit != -1; int max_step = parsed.unroll_explicit + parsed.unroll_implicit + 1; s_tir::RewriteUnroll(sch, unroll_explicit, max_step, block_rv, loop_rvs[0]); diff --git a/src/s_tir/meta_schedule/postproc/rewrite_reduction_block.cc b/src/s_tir/meta_schedule/postproc/rewrite_reduction_block.cc index 36dd9a6f2e9d..b6bc04c3f656 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_reduction_block.cc @@ -66,7 +66,7 @@ struct ReductionBlockFinder : private StmtVisitor { } auto f_find = [this](const VarNode* var) -> bool { return thread_bound_loop_vars_.count(var); }; const SBlockNode* block = realize->block.get(); - ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size()); + TVM_FFI_ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size()); int n = block->iter_vars.size(); for (int i = 0; i < n; ++i) { IterVar iter_var = block->iter_vars[i]; diff --git a/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc b/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc index c633504d3e6a..926aed03cd2b 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc @@ -51,9 +51,9 @@ void CollectTensorizationJobs( } else if (block_name.find("init") && vectorize_init_loop) { jobs->emplace_back(block_name, func_name, [sch](s_tir::SBlockRV block) { ffi::Array child_blocks = sch->GetChildBlocks(block); - ICHECK(child_blocks.size() == 1); + TVM_FFI_ICHECK(child_blocks.size() == 1); ffi::Array init_loops = sch->GetLoops(child_blocks[0]); - ICHECK(init_loops.size() == 1); + TVM_FFI_ICHECK(init_loops.size() == 1); sch->Vectorize(init_loops[0]); }); } diff --git a/src/s_tir/meta_schedule/postproc/rewrite_unbound_block.cc b/src/s_tir/meta_schedule/postproc/rewrite_unbound_block.cc index 578c3ad5ca2d..cf29cb503d98 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_unbound_block.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_unbound_block.cc @@ -90,11 +90,11 @@ class RewriteUnboundBlockNode : public PostprocNode { public: // Inherited from PostprocNode void InitializeWithTuneContext(const TuneContext& context) final { - CHECK(context->target.defined()) << "ValueError: target is not defined"; + TVM_FFI_CHECK(context->target.defined(), ValueError) << "target is not defined"; ffi::Optional max_threads_per_block = context->target.value()->GetAttr("max_threads_per_block"); - CHECK(max_threads_per_block.defined()) - << "ValueError: missing attribute `max_threads_per_block` in the target"; + TVM_FFI_CHECK(max_threads_per_block.defined(), ValueError) + << "missing attribute `max_threads_per_block` in the target"; this->max_threads_per_block_ = max_threads_per_block.value().IntValue(); } @@ -125,7 +125,7 @@ bool RewriteUnboundBlockNode::Apply(const s_tir::Schedule& sch) { using s_tir::LoopRV; using s_tir::SBlockRV; using s_tir::Schedule; - ICHECK_NE(this->max_threads_per_block_, -1); + TVM_FFI_ICHECK_NE(this->max_threads_per_block_, -1); auto get_factor = [t = this->max_threads_per_block_](int max_extent) -> ExprRV { return Integer(std::min(t, max_extent)); }; diff --git a/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc b/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc index bce597a6b777..fa71fb01311e 100644 --- a/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc @@ -30,7 +30,7 @@ class ThreadExtentChecker : private StmtVisitor { public: static bool Check(const Stmt& stmt, int thread_warp_size) { try { - ICHECK(thread_warp_size > 0); + TVM_FFI_ICHECK(thread_warp_size > 0); ThreadExtentChecker checker(thread_warp_size); checker.VisitStmt(stmt); return true; @@ -106,11 +106,11 @@ namespace meta_schedule { /*! \brief Extract attribute from a target. */ Integer Extract(const Target& target, const char* name) { - ICHECK(target.defined()); + TVM_FFI_ICHECK(target.defined()); if (ffi::Optional v = target->GetAttr(name)) { return v.value(); } - LOG(FATAL) << "AttributedError: \"" << name << "\" is not defined in the target"; + TVM_FFI_THROW(AttributedError) << "\"" << name << "\" is not defined in the target"; throw; } @@ -122,7 +122,7 @@ class VerifyGPUCodeNode : public PostprocNode { int thread_warp_size_ = -1; void InitializeWithTuneContext(const TuneContext& context) final { - ICHECK(context->target.defined()); + TVM_FFI_ICHECK(context->target.defined()); this->target_ = context->target.value(); this->target_constraints_ = ffi::Map{ {"max_shared_memory_per_block", Extract(this->target_, "max_shared_memory_per_block")}, diff --git a/src/s_tir/meta_schedule/postproc/verify_vtcm_limit.cc b/src/s_tir/meta_schedule/postproc/verify_vtcm_limit.cc index 1e9ccc965255..4d4e0b3936ac 100644 --- a/src/s_tir/meta_schedule/postproc/verify_vtcm_limit.cc +++ b/src/s_tir/meta_schedule/postproc/verify_vtcm_limit.cc @@ -30,9 +30,9 @@ class VerifyVTCMLimitNode : public PostprocNode { Integer vtcm_capacity; void InitializeWithTuneContext(const TuneContext& context) final { - ICHECK(context->target.defined()); + TVM_FFI_ICHECK(context->target.defined()); Target target = context->target.value(); - ICHECK(target->kind->name == "hexagon"); + TVM_FFI_ICHECK(target->kind->name == "hexagon"); // The value of 0 will disable VTCM verification. vtcm_capacity = target->GetAttr("vtcm-capacity").value_or(0); } diff --git a/src/s_tir/meta_schedule/profiler.cc b/src/s_tir/meta_schedule/profiler.cc index f5f16c69522c..98dcfb158625 100644 --- a/src/s_tir/meta_schedule/profiler.cc +++ b/src/s_tir/meta_schedule/profiler.cc @@ -38,9 +38,10 @@ ffi::Map ProfilerNode::Get() const { } ffi::String ProfilerNode::Table() const { - CHECK(!stats_sec.empty()) << "ValueError: The stats are empty. Please run the profiler first."; - CHECK(stats_sec.count("Total")) - << "ValueError: The total time is not recorded. This method should be called only after " + TVM_FFI_CHECK(!stats_sec.empty(), ValueError) + << "The stats are empty. Please run the profiler first."; + TVM_FFI_CHECK(stats_sec.count("Total"), ValueError) + << "The total time is not recorded. This method should be called only after " "exiting the profiler's with scope."; double total = stats_sec.at("Total"); struct Entry { @@ -62,7 +63,7 @@ ffi::String ProfilerNode::Table() const { p.Separator(); for (int i = 0, n = table_entry.size(); i < n; ++i) { if (i == 0) { - p.Row() << "" << table_entry[i].name << table_entry[i].minutes << table_entry[i].percentage; + p.Row() << table_entry[i].name << table_entry[i].minutes << table_entry[i].percentage; } else { p.Row() << i << table_entry[i].name << table_entry[i].minutes << table_entry[i].percentage; } diff --git a/src/s_tir/meta_schedule/schedule/cpu/winograd.cc b/src/s_tir/meta_schedule/schedule/cpu/winograd.cc index 0d9b65af38d2..6c2839877e94 100644 --- a/src/s_tir/meta_schedule/schedule/cpu/winograd.cc +++ b/src/s_tir/meta_schedule/schedule/cpu/winograd.cc @@ -35,19 +35,19 @@ static ffi::Array ScheduleDataPack(s_tir::Schedule sch, s_tir::SB std::vector tiled, std::vector unrolled) { using namespace tvm::tir; - ICHECK_EQ(tiled.size(), 2); - ICHECK_EQ(unrolled.size(), 4); + TVM_FFI_ICHECK_EQ(tiled.size(), 2); + TVM_FFI_ICHECK_EQ(unrolled.size(), 4); ffi::Array factors{ffi::UnsafeInit()}; ffi::Array loops = sch->GetLoops(block); - ICHECK_EQ(loops.size(), 6); + TVM_FFI_ICHECK_EQ(loops.size(), 6); factors = sch->SamplePerfectTile(loops[tiled[0]], /*n=*/2, /*max_innermost_factor=*/64); ffi::Array t0 = sch->Split(loops[tiled[0]], {factors.begin(), factors.end()}); - ICHECK_EQ(t0.size(), 2); + TVM_FFI_ICHECK_EQ(t0.size(), 2); factors = sch->SamplePerfectTile(loops[tiled[1]], /*n=*/2, /*max_innermost_factor=*/64); ffi::Array t1 = sch->Split(loops[tiled[1]], {factors.begin(), factors.end()}); - ICHECK_EQ(t1.size(), 2); + TVM_FFI_ICHECK_EQ(t1.size(), 2); sch->Unroll(loops[unrolled[0]]); sch->Unroll(loops[unrolled[1]]); diff --git a/src/s_tir/meta_schedule/schedule/cuda/thread_bind.cc b/src/s_tir/meta_schedule/schedule/cuda/thread_bind.cc index f29ea1e45f72..5c661cc8ad28 100644 --- a/src/s_tir/meta_schedule/schedule/cuda/thread_bind.cc +++ b/src/s_tir/meta_schedule/schedule/cuda/thread_bind.cc @@ -80,7 +80,7 @@ ffi::Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_thread } ExprRV factor = get_factor(std::min(extent, max_threads_per_block)); ffi::Array splits = sch->Split(loop, {std::nullopt, factor}); - ICHECK_EQ(splits.size(), 2); + TVM_FFI_ICHECK_EQ(splits.size(), 2); sch->Bind(splits[0], "blockIdx.x"); sch->Bind(splits[1], "threadIdx.x"); return {splits[0], splits[1]}; @@ -88,7 +88,7 @@ ffi::Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_thread ffi::Array splits = sch->Split(loop, {std::nullopt, Integer(max_threadblocks), // Integer(max_threads_per_block)}); - ICHECK_EQ(splits.size(), 3); + TVM_FFI_ICHECK_EQ(splits.size(), 3); sch->Reorder({splits[1], splits[2], splits[0]}); sch->Bind(splits[1], "blockIdx.x"); sch->Bind(splits[2], "threadIdx.x"); @@ -149,7 +149,7 @@ void BindBlockThreadIdx(Schedule sch, SBlockRV block_rv, // return; } if (i_block_idx != -1 && i_thread_idx == -1) { - ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not"; + TVM_FFI_ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not"; throw; } LoopRV loop_rv{ffi::UnsafeInit()}; diff --git a/src/s_tir/meta_schedule/schedule/cuda/winograd.cc b/src/s_tir/meta_schedule/schedule/cuda/winograd.cc index 941643404824..8ac211b338ed 100644 --- a/src/s_tir/meta_schedule/schedule/cuda/winograd.cc +++ b/src/s_tir/meta_schedule/schedule/cuda/winograd.cc @@ -39,19 +39,19 @@ static ffi::Array ScheduleDataPack(s_tir::Schedule sch, s_tir::SB std::vector unrolled) { // This method is used for NHWC layout only. Will likely be refactored into a more schedule using namespace tvm::tir; - ICHECK_EQ(tiled.size(), 2); - ICHECK_EQ(unrolled.size(), 4); + TVM_FFI_ICHECK_EQ(tiled.size(), 2); + TVM_FFI_ICHECK_EQ(unrolled.size(), 4); ffi::Array factors{ffi::UnsafeInit()}; ffi::Array loops = sch->GetLoops(block); - ICHECK_EQ(loops.size(), 6); + TVM_FFI_ICHECK_EQ(loops.size(), 6); factors = sch->SamplePerfectTile(loops[tiled[0]], /*n=*/2, /*max_innermost_factor=*/64); ffi::Array t0 = sch->Split(loops[tiled[0]], {factors.begin(), factors.end()}); - ICHECK_EQ(t0.size(), 2); + TVM_FFI_ICHECK_EQ(t0.size(), 2); factors = sch->SamplePerfectTile(loops[tiled[1]], /*n=*/2, /*max_innermost_factor=*/64); ffi::Array t1 = sch->Split(loops[tiled[1]], {factors.begin(), factors.end()}); - ICHECK_EQ(t1.size(), 2); + TVM_FFI_ICHECK_EQ(t1.size(), 2); sch->Unroll(loops[unrolled[0]]); sch->Unroll(loops[unrolled[1]]); @@ -91,7 +91,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { int64_t max_threadblocks = 256; int64_t max_threads_per_block = 1024; ffi::Array loops = sch->GetLoops(data_pack); - ICHECK_EQ(loops.size(), 8); + TVM_FFI_ICHECK_EQ(loops.size(), 8); BindSpatialLoop(sch, sch->Fuse({loops[0], loops[1], loops[2], loops[3]}), max_threadblocks, max_threads_per_block); } @@ -104,7 +104,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { int64_t max_threadblocks = 256; int64_t max_threads_per_block = 1024; ffi::Array loops = sch->GetLoops(inverse); - ICHECK_EQ(loops.size(), 8); + TVM_FFI_ICHECK_EQ(loops.size(), 8); BindSpatialLoop(sch, sch->Fuse({loops[0], loops[1], loops[2], loops[3]}), max_threadblocks, max_threads_per_block); return {sch}; @@ -118,7 +118,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { LoopRV outer{ffi::UnsafeInit()}; { ffi::Array loops = sch->GetLoops(data_pack); - ICHECK_EQ(loops.size(), 6); + TVM_FFI_ICHECK_EQ(loops.size(), 6); sch->Reorder({loops[2], loops[3], loops[0], loops[1], loops[4], loops[5]}); sch->Unroll(loops[0]); sch->Unroll(loops[1]); @@ -149,7 +149,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { { SBlockRV output = sch->GetConsumers(inverse)[0]; ffi::Array nchw = sch->GetLoops(output); - ICHECK_EQ(nchw.size(), 4); + TVM_FFI_ICHECK_EQ(nchw.size(), 4); ffi::Array hs = sch->Split(nchw[2], {std::nullopt, Integer(tile_size)}); ffi::Array ws = sch->Split(nchw[3], {std::nullopt, Integer(tile_size)}); sch->Reorder({hs[0], ws[0], hs[1], ws[1]}); @@ -159,7 +159,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { sch->ComputeAt(inverse, /*loop_rv=*/outer, /*preserve_unit_loops=*/true); sch->SetScope(inverse, /*buffer_index=*/0, /*storage_scope=*/"local"); ffi::Array loops = sch->GetLoops(inverse); - ICHECK_EQ(loops.size(), 10); + TVM_FFI_ICHECK_EQ(loops.size(), 10); sch->Unroll(loops[6]); sch->Unroll(loops[7]); sch->Unroll(loops[8]); diff --git a/src/s_tir/meta_schedule/schedule/generic/winograd.cc b/src/s_tir/meta_schedule/schedule/generic/winograd.cc index 78f4ed19b708..4a5e25dbac11 100644 --- a/src/s_tir/meta_schedule/schedule/generic/winograd.cc +++ b/src/s_tir/meta_schedule/schedule/generic/winograd.cc @@ -43,7 +43,7 @@ SBlockRV GetWinogradProducerAndInlineConst(Schedule sch, SBlockRV block) { results.push_back(producer); } } - ICHECK_EQ(results.size(), 1); + TVM_FFI_ICHECK_EQ(results.size(), 1); return results[0]; } diff --git a/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc b/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc index 21909e752c83..f26fa9dc5127 100644 --- a/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc @@ -28,7 +28,7 @@ class AddRFactorNode : public ScheduleRuleNode { public: // Inherited from ScheduleRuleNode void InitializeWithTuneContext(const TuneContext& context) final { - ICHECK(context->target.defined()); + TVM_FFI_ICHECK(context->target.defined()); Target target = context->target.value(); this->max_parallel_basic_ = GetTargetNumCores(target); if (this->max_jobs_per_core != -1) { @@ -110,7 +110,7 @@ ffi::Array AddRFactorNode::Apply(const s_tir::Schedule& sch, try { const s_tir::SBlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops); ffi::Array axes = sch_tmp->GetLoops(block_rf); - ICHECK_GT(axes.size(), num_spatial_loops); + TVM_FFI_ICHECK_GT(axes.size(), num_spatial_loops); // Annotate that the rfactor block, which is now the producer of the original block, needs to // be considered by the rule Random-Compute-Location. diff --git a/src/s_tir/meta_schedule/schedule_rule/apply_custom_rule.cc b/src/s_tir/meta_schedule/schedule_rule/apply_custom_rule.cc index 2a04cbe6be25..a944d8fa8eab 100644 --- a/src/s_tir/meta_schedule/schedule_rule/apply_custom_rule.cc +++ b/src/s_tir/meta_schedule/schedule_rule/apply_custom_rule.cc @@ -28,7 +28,8 @@ class ApplyCustomRuleNode : public ScheduleRuleNode { public: // Inherited from ScheduleRuleNode void InitializeWithTuneContext(const TuneContext& context) final { - CHECK(context->target.defined()) << "ValueError: Target is not defined in the tune context."; + TVM_FFI_CHECK(context->target.defined(), ValueError) + << "Target is not defined in the tune context."; this->target_ = context->target; } @@ -39,8 +40,8 @@ class ApplyCustomRuleNode : public ScheduleRuleNode { // Inherited from ScheduleRuleNode ffi::Array Apply(const s_tir::Schedule& sch, const s_tir::SBlockRV& block_rv) final { - CHECK(this->target_.defined()) - << "ValueError: ApplyCustomRule is not initialized with TuneContext that has a Target."; + TVM_FFI_CHECK(this->target_.defined(), ValueError) + << "ApplyCustomRule is not initialized with TuneContext that has a Target."; ffi::Array keys = this->target_.value()->keys; if (ffi::Optional ann = s_tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule")) { diff --git a/src/s_tir/meta_schedule/schedule_rule/auto_bind.cc b/src/s_tir/meta_schedule/schedule_rule/auto_bind.cc index 7b650643ec4a..0a0f41dcee7f 100644 --- a/src/s_tir/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/s_tir/meta_schedule/schedule_rule/auto_bind.cc @@ -32,11 +32,11 @@ class AutoBindNode : public ScheduleRuleNode { public: // Inherited from ScheduleRuleNode void InitializeWithTuneContext(const TuneContext& context) final { - CHECK(context->target.defined()) << "ValueError: target is not defined"; + TVM_FFI_CHECK(context->target.defined(), ValueError) << "target is not defined"; ffi::Optional max_threads_per_block = context->target.value()->GetAttr("max_threads_per_block"); - CHECK(max_threads_per_block.defined()) - << "ValueError: missing attribute `max_threads_per_block` in the target"; + TVM_FFI_CHECK(max_threads_per_block.defined(), ValueError) + << "missing attribute `max_threads_per_block` in the target"; this->max_threads_per_block_ = max_threads_per_block.value().IntValue(); } @@ -67,7 +67,7 @@ class AutoBindNode : public ScheduleRuleNode { ffi::Array AutoBindNode::Apply(const s_tir::Schedule& sch, const s_tir::SBlockRV& block_rv) { - ICHECK_NE(this->max_threads_per_block_, -1); + TVM_FFI_ICHECK_NE(this->max_threads_per_block_, -1); auto get_factor = MakeFactorSampler(sch, this->thread_extents_); BindBlockThreadIdx(sch, block_rv, max_threadblocks_, max_threads_per_block_, get_factor); return {sch}; diff --git a/src/s_tir/meta_schedule/schedule_rule/auto_inline.cc b/src/s_tir/meta_schedule/schedule_rule/auto_inline.cc index 263f880c1a57..b4433f160c8a 100644 --- a/src/s_tir/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/s_tir/meta_schedule/schedule_rule/auto_inline.cc @@ -52,7 +52,7 @@ bool IsInSpatialPrimFunc(const s_tir::Schedule& sch, const tir::StmtSRef& block_ const StmtSRefNode* sref = block_sref.get(); for (; sref->parent != nullptr; sref = sref->parent) { } - ICHECK(sref->stmt != nullptr && sref->stmt->IsInstance()); + TVM_FFI_ICHECK(sref->stmt != nullptr && sref->stmt->IsInstance()); return IsSpatialPrimFunc(ffi::GetRef(GetRootPrimFunc(sch->mod(), sref->stmt, nullptr))); } diff --git a/src/s_tir/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/s_tir/meta_schedule/schedule_rule/cross_thread_reduction.cc index 047505eb6fcb..12e83f6c078b 100644 --- a/src/s_tir/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/s_tir/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -28,7 +28,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { public: // Inherited from ScheduleRuleNode void InitializeWithTuneContext(const TuneContext& context) final { - ICHECK(context->target.defined()); + TVM_FFI_ICHECK(context->target.defined()); Target target = context->target.value(); ffi::Optional opt_max_threads_per_block = @@ -81,8 +81,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { ffi::Array probs(n_candidate, FloatImm(DataType::Float(32), 1.0 / n_candidate)); s_tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs); if (fusible) { - ICHECK(target_sblock.defined()); - ICHECK(target_loop.defined()); + TVM_FFI_ICHECK(target_sblock.defined()); + TVM_FFI_ICHECK(target_loop.defined()); // Step 3.1. // - If the outer loops of `target_sblock` haven't been bound to "threadIdx.x", we should @@ -159,8 +159,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { auto fcheck = [&](const Any& a) -> bool { return a.as() == loop.get(); }; int i = std::find_if(inst->outputs.begin(), inst->outputs.end(), fcheck) - inst->outputs.begin(); - CHECK(inst->inputs[1 + i] != nullptr) - << "ValueError: Extracting an extent which needs inference is not supported so far"; + TVM_FFI_CHECK(inst->inputs[1 + i] != nullptr, ValueError) + << "Extracting an extent which needs inference is not supported so far"; *extent = Downcast(inst->inputs[1 + i]); return true; } @@ -182,7 +182,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { } } } - CHECK(false) << "ValueError: Unable to get the extent of \"threadIdx.x\""; + TVM_FFI_CHECK(false, ValueError) << "Unable to get the extent of \"threadIdx.x\""; throw; } @@ -292,7 +292,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { ScheduleRule ScheduleRule::CrossThreadReduction(ffi::Array thread_extents) { for (const auto& extent : thread_extents) { - CHECK(extent->value > 0) << "ValueError: The candidates of thread extent must be positive"; + TVM_FFI_CHECK(extent->value > 0, ValueError) + << "The candidates of thread extent must be positive"; } ObjectPtr n = ffi::make_object(); n->thread_extents = std::move(thread_extents); diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc index 8a3e237cdeb1..f9823e7ad656 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -204,7 +204,7 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state, // Step 1. Assuming trivial binding, pair the loops and their iter-var-types ffi::Array loops = sch->GetLoops(block_rv); std::vector iter_types = GetSBlockVarTypes(sch->GetSRef(state->block_rv)); - ICHECK_EQ(loops.size(), iter_types.size()); + TVM_FFI_ICHECK_EQ(loops.size(), iter_types.size()); // Step 2. For each loop axis, tile it int64_t spatial_loop_product = 1; @@ -212,7 +212,7 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state, std::for_each(iter_types.begin(), iter_types.end(), [&](const auto& iter_type) { if (iter_type == IterVarType::kDataPar) total_spatial_loop_num++; }); - CHECK_GE(total_spatial_loop_num, tile_inner_most_space_loop_num); + TVM_FFI_ICHECK_GE(total_spatial_loop_num, tile_inner_most_space_loop_num); if (tile_inner_most_space_loop_num < 0) tile_inner_most_space_loop_num = total_spatial_loop_num; int outer_most_spatial_loop_skipped_num = total_spatial_loop_num - tile_inner_most_space_loop_num; @@ -294,7 +294,7 @@ std::vector MultiLevelTilingNode::AddReadReuse(State state) const { if (config.req == ReuseType::kNoReuse) { return {std::move(state)}; } - ICHECK(config.req != ReuseType::kMayReuse); + TVM_FFI_ICHECK(config.req != ReuseType::kMayReuse); const SBlockRV& block_rv = state->block_rv; std::vector results; results.reserve(config.levels.size()); @@ -366,7 +366,7 @@ void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch, const s_tir::SBlockRV& block) const { // Filter out invalid vector lanes according to the data type. const tir::SBlockNode* block_node = (*sch)->GetSRef(block)->StmtAs(); - ICHECK_EQ(block_node->writes.size(), 1); + TVM_FFI_ICHECK_EQ(block_node->writes.size(), 1); const runtime::DataType dtype = block_node->writes[0]->buffer->dtype; std::function f_filter = nullptr; if (dtype == runtime::DataType::Float(32)) { diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.h b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.h index c76ad1dfd21a..adfd3ebdb9ac 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.h @@ -74,7 +74,7 @@ inline ReuseType Str2ReuseType(const ffi::String& str) { } else if (str == "must") { return ReuseType::kMustReuse; } else { - LOG(FATAL) << "ValueError: Unknown ReuseType: " << str; + TVM_FFI_THROW(ValueError) << "Unknown ReuseType: " << str; throw; } } @@ -96,7 +96,7 @@ struct ReuseConfig { : req(Str2ReuseType(Downcast(config.at("req")))), levels(support::AsVector(Downcast>(config.at("levels")))), scope(Downcast(config.at("scope"))) { - ICHECK_EQ(config.size(), 3); + TVM_FFI_ICHECK_EQ(config.size(), 3); } }; @@ -257,7 +257,7 @@ ObjectPtr MultiLevelTilingInitCommon( } else if (c == 'R') { n->r_indices_.push_back(i); } else { - LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure; + TVM_FFI_THROW(ValueError) << "Invalid tiling structure: " << structure; } } n->thread_warp_size_ = -1; diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 1a520cb8d3b1..b8162acdbbb7 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -60,7 +60,7 @@ struct TensorCoreIntrinGroup { TensorCoreIntrinGroup TensorCoreIntrinGroup::FromConfig( const ffi::Map& config) { auto f_initialize_intrin = [&config](ffi::String key_name, ffi::String* intrin_name) { - CHECK(config.count(key_name)) << "ValueError: " << key_name << " is not set."; + TVM_FFI_CHECK(config.count(key_name), ValueError) << key_name << " is not set."; *intrin_name = config.at(key_name); // Check the existence of the intrin tir::TensorIntrin::Get(*intrin_name); @@ -292,7 +292,7 @@ void MultiLevelTilingTensorCoreNode::TileAndAnnotateTensorize( Schedule* sch, const SBlockRV& block_rv, const ffi::String& intrin_name, const ffi::String& permuted_layout_annotate_value) const { ffi::Optional loop = s_tir::TileWithTensorIntrin(*sch, block_rv, intrin_name).value(); - ICHECK(loop.defined()); + TVM_FFI_ICHECK(loop.defined()); SBlockRV blockized_outer = (*sch)->Blockize(loop.value()); (*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name); if (!permuted_layout_annotate_value.empty()) { @@ -305,7 +305,7 @@ std::vector MultiLevelTilingTensorCoreNode::MMAAddReadReuse(TensorCoreSta if (config.req == ReuseType::kNoReuse) { return {std::move(state)}; } - ICHECK(config.req != ReuseType::kMayReuse); + TVM_FFI_ICHECK(config.req != ReuseType::kMayReuse); const SBlockRV& block_rv = state->block_rv; std::vector results; results.reserve(config.levels.size()); @@ -358,7 +358,7 @@ std::vector MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta return {}; } std::vector iter_types = GetSBlockVarTypes(sch->GetSRef(state->block_rv)); - ICHECK_EQ(loops.size(), iter_types.size()); + TVM_FFI_ICHECK_EQ(loops.size(), iter_types.size()); // Step 2. For each loop axis, tile it int64_t spatial_loop_product = 1; std::vector> tiles(s_indices_.size() + r_indices_.size()); @@ -454,7 +454,7 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa // Get the tile index of the warp id (i.e. threadIdx.y) auto it = std::find(tile_binds.begin(), tile_binds.end(), "threadIdx.y"); - ICHECK(it != tile_binds.end()); + TVM_FFI_ICHECK(it != tile_binds.end()); auto tile_index_warp_id = std::distance(tile_binds.begin(), it); // Get the extent of loop indicated by `loop_idx` inside the warp scope. @@ -471,7 +471,7 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa } factors.push_back(s_factors[loop_idx]); } - ICHECK(!factors.empty()); + TVM_FFI_ICHECK(!factors.empty()); if (factors.size() == 1) { return factors[0]; } @@ -489,7 +489,7 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa Schedule& sch = state->sch; int buffer_ndim = static_cast(sch->Get(state->block_rv)->writes[0]->buffer->shape.size()); // The dimension of the buffer should be larger or same as that of the tensor intrin. - ICHECK_GE(buffer_ndim, 2); + TVM_FFI_ICHECK_GE(buffer_ndim, 2); int num_higher_dims = buffer_ndim - 2; auto index_map = @@ -566,7 +566,7 @@ std::vector MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore( // Get the loops other than the innermost two loops (accum_m and accum_n). auto f_get_loops = [&](const SBlockRV& block_rv) -> std::array { ffi::Array buffer_loops = sch->GetLoops(block_rv); - ICHECK_GT(buffer_loops.size(), 6); + TVM_FFI_ICHECK_GT(buffer_loops.size(), 6); return {buffer_loops[buffer_loops.size() - 6], buffer_loops[buffer_loops.size() - 5], buffer_loops[buffer_loops.size() - 4], buffer_loops[buffer_loops.size() - 3]}; }; @@ -590,7 +590,7 @@ std::vector MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore( state->intrin_group.store_intrin); ffi::Array buffer_loops = sch->GetLoops(state->write_reuse[0]); - ICHECK_GT(buffer_loops.size(), 5); + TVM_FFI_ICHECK_GT(buffer_loops.size(), 5); sch->Fuse(ffi::Array{buffer_loops.end() - 5, // The src shmem is always 2D buffer_loops.end()}); AnnotateCooperativeFetching(&sch, state->write_reuse[0]); @@ -601,7 +601,8 @@ std::vector MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore( TensorCoreState state) const { const ffi::Array& r_tiles = state->tiles[r_indices_[1]]; Schedule& sch = state->sch; - ICHECK(!r_tiles.empty()) << "ValueError: Cannot find the suitable reduction loop in the block"; + TVM_FFI_CHECK(!r_tiles.empty(), ValueError) + << "Cannot find the suitable reduction loop in the block"; auto f_tensorize_load = [&](int read_index, ffi::String scope, ffi::String intrin_name) { auto cache_read = sch->CacheRead(state->block_rv, read_index, scope); @@ -653,7 +654,7 @@ std::vector MultiLevelTilingTensorCoreNode::AddSoftwarePipeline( const ffi::Array& tiles = state->tiles[r_index]; for (const LoopRV& tile : tiles) { const auto* extent = sch->Get(tile)->extent.as(); - ICHECK(extent != nullptr) << "Dynamic extent is not supported."; + TVM_FFI_ICHECK(extent != nullptr) << "Dynamic extent is not supported."; reduction_length *= extent->value; } } @@ -789,14 +790,14 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( // Transform the layout of reindex buffers accordingly. // The index map defines the mapping for the computation block. We need to extract the sub index // map to transform the load and store block. - ICHECK_EQ(mapping_info->mappings.size(), 1U); // assume only one mapping is present + TVM_FFI_ICHECK_EQ(mapping_info->mappings.size(), 1U); // assume only one mapping is present const tir::IndexMap& index_map = mapping_info->mappings[0]; // Find the correspondence between block iters and the iters in the index map. std::unordered_map lhs_to_index_map_src; std::unordered_map rhs_to_index_map_tgt; std::unordered_set unmapped_index_map_src; - ICHECK_EQ(mapping_info->lhs_iters.size(), index_map->initial_indices.size()); + TVM_FFI_ICHECK_EQ(mapping_info->lhs_iters.size(), index_map->initial_indices.size()); for (int i = 0; i < static_cast(mapping_info->lhs_iters.size()); ++i) { lhs_to_index_map_src[mapping_info->lhs_iters[i]->var] = index_map->initial_indices[i]; } @@ -807,10 +808,10 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( // kept as a outer loop after tensorization. int offset = static_cast(index_map->final_indices.size()) - static_cast(mapping_info->rhs_iters.size()); - ICHECK_GE(offset, 0); + TVM_FFI_ICHECK_GE(offset, 0); for (int i = 0; i < offset; ++i) { const tir::VarNode* var_ptr = index_map->final_indices[i].as(); - ICHECK(var_ptr != nullptr); + TVM_FFI_ICHECK(var_ptr != nullptr); unmapped_index_map_src.insert(ffi::GetRef(var_ptr)); } for (int i = offset; i < static_cast(index_map->final_indices.size()); ++i) { @@ -822,9 +823,9 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( std::vector sub_index_map_tgt; const tir::Buffer& rhs_buffer = mapping_info->lhs_buffer_map[lhs_buffer]; for (const Range& range : lhs_region) { - ICHECK(tir::is_one(range->extent)); + TVM_FFI_ICHECK(tir::is_one(range->extent)); const tir::VarNode* var_ptr = range->min.as(); - ICHECK(var_ptr != nullptr); + TVM_FFI_ICHECK(var_ptr != nullptr); const tir::Var& lhs_representer = lhs_to_index_map_src[ffi::GetRef(var_ptr)]; sub_index_map_src.push_back(lhs_representer); if (unmapped_index_map_src.count(lhs_representer)) { @@ -833,7 +834,7 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( } for (size_t i = 0; i < mapping_info->rhs_buffer_indices[rhs_buffer].size(); ++i) { const tir::VarNode* var = mapping_info->rhs_buffer_indices[rhs_buffer][i].as(); - ICHECK(var != nullptr); + TVM_FFI_ICHECK(var != nullptr); sub_index_map_tgt.push_back(rhs_to_index_map_tgt[ffi::GetRef(var)]); } return tir::IndexMap(sub_index_map_src, sub_index_map_tgt); @@ -915,7 +916,8 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore( ffi::Optional> reuse_write, bool use_software_pipeline) { if (tile_binds.defined()) { for (const ffi::String& tile_bind : tile_binds.value()) { - CHECK_NE(tile_bind, "threadIdx.x") << "Cannot bind to threadIdx.x when using tensor core."; + TVM_FFI_ICHECK_NE(tile_bind, "threadIdx.x") + << "Cannot bind to threadIdx.x when using tensor core."; } } auto node = MultiLevelTilingInitCommon( @@ -932,10 +934,11 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore( } if (have_wmma_intrin_group) { - CHECK(node->reuse_write_.req == ReuseType::kMustReuse && - runtime::StorageScope::Create(node->reuse_write_.scope).rank == - runtime::StorageRank::kShared) - << "ValueError: Shared memory write reuse must be enabled for MultiLevelTilingTensorCore."; + TVM_FFI_CHECK(node->reuse_write_.req == ReuseType::kMustReuse && + runtime::StorageScope::Create(node->reuse_write_.scope).rank == + runtime::StorageRank::kShared, + ValueError) + << "Shared memory write reuse must be enabled for MultiLevelTilingTensorCore."; } node->use_software_pipeline = use_software_pipeline; diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc index 3625f0017092..5acc528e7d40 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc @@ -69,7 +69,7 @@ MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, SBlockRV block_rv const tir::StmtSRef block_sref = sch->GetSRef(block_rv); const tir::SBlockNode* block_node = block_sref->StmtAs(); const tir::SBlockRealize block_realize = s_tir::GetSBlockRealize(sch->state(), block_sref); - ICHECK(block_node && block_node->writes.size() == 1); + TVM_FFI_ICHECK(block_node && block_node->writes.size() == 1); const auto out_dtype = block_node->writes[0]->buffer->dtype; const int vec_len = vector_length_in_bits / out_dtype.bits(); diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc index 3c1b7a7e3ce2..ed1ffceb6fe6 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc @@ -38,7 +38,7 @@ ffi::Optional TileForIntrin(s_tir::Schedule sch, s_tir::SBlockR if (!tiled_loop_rv) { return std::nullopt; } - ICHECK(tiled_loop_rv.defined()); + TVM_FFI_ICHECK(tiled_loop_rv.defined()); s_tir::SBlockRV outer_block = sch->Blockize(tiled_loop_rv.value()); sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize, ffi::String(intrin_name)); return outer_block; @@ -105,7 +105,7 @@ ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin( ffi::Optional> vector_load_lens, ffi::Optional> reuse_read, ffi::Optional> reuse_write) { - ICHECK(tir::TensorIntrin::Get(intrin_name).defined()) + TVM_FFI_ICHECK(tir::TensorIntrin::Get(intrin_name).defined()) << "Provided tensor intrinsic " << intrin_name << " is not registered."; auto node = MultiLevelTilingInitCommon( structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); diff --git a/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 384b812164cd..036eec6bc250 100644 --- a/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -45,7 +45,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { public: // Inherited from ScheduleRuleNode void InitializeWithTuneContext(const TuneContext& context) final { - ICHECK(context->target.defined()); + TVM_FFI_ICHECK(context->target.defined()); if (this->max_jobs_per_core != -1) { Target target = context->target.value(); this->max_parallel_extent_ = GetTargetNumCores(target) * max_jobs_per_core; diff --git a/src/s_tir/meta_schedule/schedule_rule/random_compute_location.cc b/src/s_tir/meta_schedule/schedule_rule/random_compute_location.cc index b03012f02f3b..4fb9034d1aef 100644 --- a/src/s_tir/meta_schedule/schedule_rule/random_compute_location.cc +++ b/src/s_tir/meta_schedule/schedule_rule/random_compute_location.cc @@ -47,7 +47,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode { true)) { producers = sch->GetProducers(block_rv); sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer); - ICHECK_EQ(producers.size(), 1); + TVM_FFI_ICHECK_EQ(producers.size(), 1); } // Step 2. Transform the input block. diff --git a/src/s_tir/meta_schedule/schedule_rule/schedule_rule.cc b/src/s_tir/meta_schedule/schedule_rule/schedule_rule.cc index 2b539d9574e5..f237498726c4 100644 --- a/src/s_tir/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/s_tir/meta_schedule/schedule_rule/schedule_rule.cc @@ -26,19 +26,19 @@ namespace s_tir { namespace meta_schedule { void PyScheduleRuleNode::InitializeWithTuneContext(const TuneContext& context) { - ICHECK(f_initialize_with_tune_context != nullptr) + TVM_FFI_ICHECK(f_initialize_with_tune_context != nullptr) << "PyScheduleRule's InitializeWithTuneContext method not implemented!"; f_initialize_with_tune_context(context); } ffi::Array PyScheduleRuleNode::Apply(const s_tir::Schedule& sch, const s_tir::SBlockRV& block) { - ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!"; + TVM_FFI_ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!"; return f_apply(sch, block); } ScheduleRule PyScheduleRuleNode::Clone() const { - ICHECK(f_clone != nullptr) << "PyScheduleRule's Clone method not implemented!"; + TVM_FFI_ICHECK(f_clone != nullptr) << "PyScheduleRule's Clone method not implemented!"; return f_clone(); } @@ -454,9 +454,9 @@ ffi::Array ScheduleRule::DefaultARM(const ffi::String& type) { TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { const auto* self = n.as(); - ICHECK(self); + TVM_FFI_ICHECK(self); PyScheduleRuleNode::FAsString f_as_string = (*self).f_as_string; - ICHECK(f_as_string != nullptr) << "PyScheduleRule's AsString method not implemented!"; + TVM_FFI_ICHECK(f_as_string != nullptr) << "PyScheduleRule's AsString method not implemented!"; p->stream << f_as_string(); }); diff --git a/src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc b/src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc index 90c055c75f05..fabe50dd60f9 100644 --- a/src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc @@ -24,9 +24,10 @@ #include "../module_equality.h" #include "../utils.h" -#define TVM_META_SCHEDULE_CHECK_PROB_RANGE(p, name) \ - CHECK(0.0 <= (p) && (p) <= 1.0) << "ValueError: name should be within [0, 1], " \ - << "but get `" << #p << " = " << (p) << '\''; +#define TVM_META_SCHEDULE_CHECK_PROB_RANGE(p, name) \ + TVM_FFI_CHECK(0.0 <= (p) && (p) <= 1.0, ValueError) \ + << "name should be within [0, 1], " \ + << "but get `" << #p << " = " << (p) << '\''; namespace tvm { namespace s_tir { @@ -236,7 +237,8 @@ std::vector PredictNormalizedScore(const std::vector& candidat const TuneContext& context, const CostModel& cost_model) { auto _ = Profiler::TimedScope("EvoSearch/Evolve/PredictNormalizedScore"); - ICHECK(!candidates.empty()) << "Candidates given for score prediction can not be empty list!"; + TVM_FFI_ICHECK(!candidates.empty()) + << "Candidates given for score prediction can not be empty list!"; std::vector scores = cost_model->Predict(context, AssembleCandidates(candidates)); for (double& score : scores) { score = std::max(0.0, score); @@ -402,13 +404,13 @@ class EvolutionarySearchNode : public SearchStrategyNode { EvolutionarySearchNode, SearchStrategyNode); void InitializeWithTuneContext(const TuneContext& ctx) final { - CHECK(ctx->num_threads > 0) << "ValueError: `TuneContext.num_threads` must be > 0"; - CHECK(ctx->space_generator.defined()) - << "ValueError: `TuneContext.space_generator` must be defined"; - CHECK(ctx->space_generator.value()->postprocs.defined()) - << "ValueError: `TuneContext.space_generator.postprocs` must be defined"; - CHECK(ctx->space_generator.value()->mutator_probs.defined()) - << "ValueError: `TuneContext.space_generator.mutator_probs` must be defined"; + TVM_FFI_CHECK(ctx->num_threads > 0, ValueError) << "`TuneContext.num_threads` must be > 0"; + TVM_FFI_CHECK(ctx->space_generator.defined(), ValueError) + << "`TuneContext.space_generator` must be defined"; + TVM_FFI_CHECK(ctx->space_generator.value()->postprocs.defined(), ValueError) + << "`TuneContext.space_generator.postprocs` must be defined"; + TVM_FFI_CHECK(ctx->space_generator.value()->mutator_probs.defined(), ValueError) + << "`TuneContext.space_generator.mutator_probs` must be defined"; this->ctx_ = ctx.get(); this->postprocs_ = ctx->space_generator.value()->postprocs.value(); this->mutator_probs_ = ctx->space_generator.value()->mutator_probs.value(); @@ -419,38 +421,41 @@ class EvolutionarySearchNode : public SearchStrategyNode { void PreTuning(int max_trials, int num_trials_per_iter, const ffi::Array& design_spaces, const ffi::Optional& database, const ffi::Optional& cost_model) final { - ICHECK(!design_spaces.empty()); - CHECK(this->ctx_ != nullptr) << "ValueError: Did you forget to initialize the TuneContext?"; - CHECK(database.defined()) << "ValueError: Database is not supplied in PreTuning. Evolutionary" - "search algorithm requires a database to be present, so that it " - "could sample from previously-explored population. If you do not " - "intent to store data on disk, please use " - "`tvm.s_tir.meta_schedule.database.MemoryDatabase`"; - CHECK(cost_model.defined()) - << "ValueError: CostModel is not supplied in PreTuning. Evolutionary search " + TVM_FFI_ICHECK(!design_spaces.empty()); + TVM_FFI_CHECK(this->ctx_ != nullptr, ValueError) + << "Did you forget to initialize the TuneContext?"; + TVM_FFI_CHECK(database.defined(), ValueError) + << "Database is not supplied in PreTuning. Evolutionary" + "search algorithm requires a database to be present, so that it " + "could sample from previously-explored population. If you do not " + "intent to store data on disk, please use " + "`tvm.s_tir.meta_schedule.database.MemoryDatabase`"; + TVM_FFI_CHECK(cost_model.defined(), ValueError) + << "CostModel is not supplied in PreTuning. Evolutionary search " "algorithm expects a cost model to filter out potentially less efficient kernels. If " "you do not expect a cost model to help, please use " "`tvm.s_tir.meta_schedule.cost_model.RandomModel`"; - CHECK(this->state_ == nullptr) - << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`."; + TVM_FFI_CHECK(this->state_ == nullptr, ValueError) + << "`PreTuning` is already invoked without corresponding `PostTuning`."; this->state_ = std::make_unique(this, max_trials, num_trials_per_iter, design_spaces, database.value(), cost_model.value()); } void PostTuning() final { - CHECK(this->state_ != nullptr) << "ValueError: `PostTuning` is invoked without corresponding " - "`PreTuning`, or `PostTuning` is already invoked."; + TVM_FFI_CHECK(this->state_ != nullptr, ValueError) + << "`PostTuning` is invoked without corresponding " + "`PreTuning`, or `PostTuning` is already invoked."; this->state_.reset(); } ffi::Optional> GenerateMeasureCandidates() final { - ICHECK(this->state_ != nullptr); + TVM_FFI_ICHECK(this->state_ != nullptr); return this->state_->GenerateMeasureCandidates(); } void NotifyRunnerResults(const ffi::Array& measure_candidates, const ffi::Array& results) final { - ICHECK(this->state_ != nullptr); + TVM_FFI_ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(measure_candidates, results); } @@ -490,11 +495,11 @@ std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int nu const IRModule& mod = data.mod; s_tir::Trace trace = measured_traces.at(trace_id); Schedule& result = results.at(trace_id); - ICHECK(!result.defined()); + TVM_FFI_ICHECK(!result.defined()); if (ffi::Optional sch = pp.Apply(mod, trace, rand_state)) { result = sch.value(); } else { - LOG(FATAL) << "ValueError: Cannot postprocess the trace:\n" << trace; + TVM_FFI_THROW(ValueError) << "Cannot postprocess the trace:\n" << trace; throw; } }; @@ -515,7 +520,7 @@ std::vector EvolutionarySearchNode::State::SampleInitPopulation(int nu TRandState* rand_state = &data.rand_state; const IRModule& mod = data.mod; Schedule& result = results.at(trace_id); - ICHECK(!result.defined()); + TVM_FFI_ICHECK(!result.defined()); int design_space_index = s_tir::SampleInt(rand_state, 0, design_spaces.size()); s_tir::Trace trace(design_spaces[design_space_index]->insts, {}); if (ffi::Optional sch = pp.Apply(mod, trace, rand_state)) { @@ -542,7 +547,7 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( IRModuleSet exists(database_->GetModuleEquality()); { auto _ = Profiler::TimedScope("EvoSearch/Evolve/Misc/CopyMeasuredWorkloads"); - ICHECK_GT(num, 0); + TVM_FFI_ICHECK_GT(num, 0); // The heap to record best schedule, we do not consider schedules that are already measured exists = this->measured_workloads_; } @@ -554,7 +559,7 @@ std::vector EvolutionarySearchNode::State::EvolveWithCostModel( { auto _ = Profiler::TimedScope("EvoSearch/Evolve/Misc"); - ICHECK_EQ(scores.size(), population.size()); + TVM_FFI_ICHECK_EQ(scores.size(), population.size()); for (int i = 0, n = population.size(); i < n; ++i) { Schedule sch = population.at(i); IRModule mod = sch->mod(); @@ -708,7 +713,7 @@ EvolutionarySearchNode::State::GenerateMeasureCandidates() { sample_num = max_trials - st; ed = max_trials; } - ICHECK_LT(st, ed); + TVM_FFI_ICHECK_LT(st, ed); int pop = self->population_size; std::vector inits; inits.reserve(pop); diff --git a/src/s_tir/meta_schedule/search_strategy/replay_func.cc b/src/s_tir/meta_schedule/search_strategy/replay_func.cc index 2eee5bafdff5..aaa34ae2ca4e 100644 --- a/src/s_tir/meta_schedule/search_strategy/replay_func.cc +++ b/src/s_tir/meta_schedule/search_strategy/replay_func.cc @@ -46,8 +46,8 @@ class ReplayFuncNode : public SearchStrategyNode { num_trials_per_iter(num_trials_per_iter), st(0), ed(num_trials_per_iter) { - CHECK(self->mod_.defined() && self->space_generator_.defined()) - << "ValueError: The search strategy has not been initialized."; + TVM_FFI_CHECK(self->mod_.defined() && self->space_generator_.defined(), ValueError) + << "The search strategy has not been initialized."; } inline ffi::Optional> GenerateMeasureCandidates(); @@ -71,9 +71,9 @@ class ReplayFuncNode : public SearchStrategyNode { SearchStrategyNode); void InitializeWithTuneContext(const TuneContext& ctx) final { - CHECK(ctx->mod.defined()) << "ValueError: TuneContext.mod is not defined"; - CHECK(ctx->space_generator.defined()) - << "ValueError: TuneContext.space_generator is not defined"; + TVM_FFI_CHECK(ctx->mod.defined(), ValueError) << "TuneContext.mod is not defined"; + TVM_FFI_CHECK(ctx->space_generator.defined(), ValueError) + << "TuneContext.space_generator is not defined"; if (!ctx->space_generator.value()->postprocs.defined()) { TVM_PY_LOG(WARNING, ctx->logger) << "`postprocs` is not defined in " << ctx->space_generator.value() @@ -90,25 +90,26 @@ class ReplayFuncNode : public SearchStrategyNode { const ffi::Array& design_spaces, const ffi::Optional& database, const ffi::Optional& cost_model) final { - CHECK(this->state_ == nullptr) - << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`."; + TVM_FFI_CHECK(this->state_ == nullptr, ValueError) + << "`PreTuning` is already invoked without corresponding `PostTuning`."; this->state_ = std::make_unique(this, max_trials, num_trials_per_iter); } void PostTuning() final { - CHECK(this->state_ != nullptr) << "ValueError: `PostTuning` is invoked without corresponding " - "`PreTuning`, or `PostTuning` is already invoked."; + TVM_FFI_CHECK(this->state_ != nullptr, ValueError) + << "`PostTuning` is invoked without corresponding " + "`PreTuning`, or `PostTuning` is already invoked."; this->state_.reset(); } ffi::Optional> GenerateMeasureCandidates() final { - ICHECK(this->state_ != nullptr); + TVM_FFI_ICHECK(this->state_ != nullptr); return this->state_->GenerateMeasureCandidates(); } void NotifyRunnerResults(const ffi::Array& measure_candidates, const ffi::Array& results) final { - ICHECK(this->state_ != nullptr); + TVM_FFI_ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(results); } diff --git a/src/s_tir/meta_schedule/search_strategy/replay_trace.cc b/src/s_tir/meta_schedule/search_strategy/replay_trace.cc index 976c80e87500..9f15fd492d14 100644 --- a/src/s_tir/meta_schedule/search_strategy/replay_trace.cc +++ b/src/s_tir/meta_schedule/search_strategy/replay_trace.cc @@ -86,9 +86,9 @@ class ReplayTraceNode : public SearchStrategyNode { SearchStrategyNode); void InitializeWithTuneContext(const TuneContext& ctx) final { - CHECK(ctx->mod.defined()) << "ValueError: TuneContext.mod is not defined"; - CHECK(ctx->space_generator.defined()) - << "ValueError: TuneContext.space_generator is not defined"; + TVM_FFI_CHECK(ctx->mod.defined(), ValueError) << "TuneContext.mod is not defined"; + TVM_FFI_CHECK(ctx->space_generator.defined(), ValueError) + << "TuneContext.space_generator is not defined"; if (!ctx->space_generator.value()->postprocs.defined()) { TVM_PY_LOG(WARNING, ctx->logger) << "`postprocs` is not defined in " << ctx->space_generator.value() @@ -106,9 +106,9 @@ class ReplayTraceNode : public SearchStrategyNode { const ffi::Array& design_spaces, const ffi::Optional& database, const ffi::Optional& cost_model) final { - ICHECK(!design_spaces.empty()); - CHECK(this->state_ == nullptr) - << "ValueError: `PreTuning` is already invoked without corresponding `PostTuning`."; + TVM_FFI_ICHECK(!design_spaces.empty()); + TVM_FFI_CHECK(this->state_ == nullptr, ValueError) + << "`PreTuning` is already invoked without corresponding `PostTuning`."; ffi::Array design_space_traces; design_space_traces.reserve(design_spaces.size()); for (const s_tir::Schedule& space : design_spaces) { @@ -119,18 +119,18 @@ class ReplayTraceNode : public SearchStrategyNode { } void PostTuning() final { - ICHECK(this->state_ != nullptr); + TVM_FFI_ICHECK(this->state_ != nullptr); this->state_.reset(); } ffi::Optional> GenerateMeasureCandidates() final { - ICHECK(this->state_ != nullptr); + TVM_FFI_ICHECK(this->state_ != nullptr); return this->state_->GenerateMeasureCandidates(); } void NotifyRunnerResults(const ffi::Array& measure_candidates, const ffi::Array& results) final { - ICHECK(this->state_ != nullptr); + TVM_FFI_ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(results); } @@ -149,7 +149,7 @@ ReplayTraceNode::State::GenerateMeasureCandidates() { return std::nullopt; } ed = std::min(ed, max_trials); - ICHECK_LT(st, ed); + TVM_FFI_ICHECK_LT(st, ed); std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); ffi::Array> per_task_result(ed - st, std::nullopt); ThreadedTraceApply pp(self->postprocs_); diff --git a/src/s_tir/meta_schedule/search_strategy/search_strategy.cc b/src/s_tir/meta_schedule/search_strategy/search_strategy.cc index db12b213c66b..3237ee51b2e1 100644 --- a/src/s_tir/meta_schedule/search_strategy/search_strategy.cc +++ b/src/s_tir/meta_schedule/search_strategy/search_strategy.cc @@ -32,7 +32,7 @@ MeasureCandidate::MeasureCandidate(s_tir::Schedule sch, ffi::Array args } void PySearchStrategyNode::InitializeWithTuneContext(const TuneContext& context) { - ICHECK(f_initialize_with_tune_context != nullptr) + TVM_FFI_ICHECK(f_initialize_with_tune_context != nullptr) << "PySearchStrategy's InitializeWithTuneContext method not implemented!"; f_initialize_with_tune_context(context); } @@ -41,17 +41,18 @@ void PySearchStrategyNode::PreTuning(int max_trials, int num_trials_per_iter, const ffi::Array& design_spaces, const ffi::Optional& database, const ffi::Optional& cost_model) { - ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!"; + TVM_FFI_ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!"; f_pre_tuning(max_trials, num_trials_per_iter, design_spaces, database, cost_model); } void PySearchStrategyNode::PostTuning() { - ICHECK(f_post_tuning != nullptr) << "PySearchStrategy's PostTuning method not implemented!"; + TVM_FFI_ICHECK(f_post_tuning != nullptr) + << "PySearchStrategy's PostTuning method not implemented!"; f_post_tuning(); } ffi::Optional> PySearchStrategyNode::GenerateMeasureCandidates() { - ICHECK(f_generate_measure_candidates != nullptr) + TVM_FFI_ICHECK(f_generate_measure_candidates != nullptr) << "PySearchStrategy's GenerateMeasureCandidates method not implemented!"; return f_generate_measure_candidates(); } @@ -59,13 +60,13 @@ ffi::Optional> PySearchStrategyNode::GenerateMeasur void PySearchStrategyNode::NotifyRunnerResults( const ffi::Array& measure_candidates, const ffi::Array& results) { - ICHECK(f_notify_runner_results != nullptr) + TVM_FFI_ICHECK(f_notify_runner_results != nullptr) << "PySearchStrategy's NotifyRunnerResults method not implemented!"; f_notify_runner_results(measure_candidates, results); } SearchStrategy PySearchStrategyNode::Clone() const { - ICHECK(f_clone != nullptr) << "PySearchStrategy's Clone method not implemented!"; + TVM_FFI_ICHECK(f_clone != nullptr) << "PySearchStrategy's Clone method not implemented!"; return f_clone(); } diff --git a/src/s_tir/meta_schedule/space_generator/post_order_apply.cc b/src/s_tir/meta_schedule/space_generator/post_order_apply.cc index 677bc89671aa..b5a6b855b975 100644 --- a/src/s_tir/meta_schedule/space_generator/post_order_apply.cc +++ b/src/s_tir/meta_schedule/space_generator/post_order_apply.cc @@ -49,7 +49,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode { ffi::Array GenerateDesignSpace(const IRModule& mod) final { using ScheduleAndUnvisitedBlocks = std::pair>; - CHECK(sch_rules.defined()) << "ValueError: `sch_rules` is not set in PostOrderApply"; + TVM_FFI_CHECK(sch_rules.defined(), ValueError) << "`sch_rules` is not set in PostOrderApply"; s_tir::Schedule sch = s_tir::Schedule::Traced( /*mod=*/mod, /*rand_state=*/ForkSeed(&this->rand_state_), diff --git a/src/s_tir/meta_schedule/space_generator/schedule_fn.cc b/src/s_tir/meta_schedule/space_generator/schedule_fn.cc index 39b4fc558346..b5b53df90282 100644 --- a/src/s_tir/meta_schedule/space_generator/schedule_fn.cc +++ b/src/s_tir/meta_schedule/space_generator/schedule_fn.cc @@ -65,16 +65,16 @@ class ScheduleFnNode : public SpaceGeneratorNode { if (auto sch = val.as()) { result.push_back(sch.value()); } else { - LOG(FATAL) << "TypeError: Expect return type of ScheduleFn to be None, Schedule or " - "List[Schedule], but got: " - << obj->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Expect return type of ScheduleFn to be None, Schedule or " + "List[Schedule], but got: " + << obj->GetTypeKey(); } } return result; } - LOG(FATAL) << "TypeError: Expect return type of ScheduleFn to be None, Schedule or " - "List[Schedule], but got: " - << obj->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Expect return type of ScheduleFn to be None, Schedule or " + "List[Schedule], but got: " + << obj->GetTypeKey(); throw; } diff --git a/src/s_tir/meta_schedule/space_generator/space_generator.cc b/src/s_tir/meta_schedule/space_generator/space_generator.cc index e3057a333b5f..3e519f62e69c 100644 --- a/src/s_tir/meta_schedule/space_generator/space_generator.cc +++ b/src/s_tir/meta_schedule/space_generator/space_generator.cc @@ -84,7 +84,7 @@ ffi::String GetRuleKindFromTarget(const Target& target) { if (target->kind->name == "c") { return "c"; } - LOG(FATAL) << "Unsupported target: " << target; + TVM_FFI_THROW(InternalError) << "Unsupported target: " << target; throw; } @@ -138,7 +138,7 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) { default_postprocs = Postproc::DefaultCPUTensorization(); default_mutator_probs = Mutator::DefaultLLVM(); } else { - LOG(FATAL) << "Unsupported kind: " << kind; + TVM_FFI_THROW(InternalError) << "Unsupported kind: " << kind; throw; } if (!sch_rules.defined()) { @@ -170,19 +170,19 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) { } void PySpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) { - ICHECK(f_initialize_with_tune_context != nullptr) + TVM_FFI_ICHECK(f_initialize_with_tune_context != nullptr) << "PySpaceGenerator's InitializeWithTuneContext method not implemented!"; f_initialize_with_tune_context(context); } ffi::Array PySpaceGeneratorNode::GenerateDesignSpace(const IRModule& mod) { - ICHECK(f_generate_design_space != nullptr) + TVM_FFI_ICHECK(f_generate_design_space != nullptr) << "PySpaceGenerator's GenerateDesignSpace method not implemented!"; return f_generate_design_space(mod); } SpaceGenerator PySpaceGeneratorNode::Clone() const { - ICHECK(f_clone != nullptr) << "PySpaceGenerator's Clone method not implemented!"; + TVM_FFI_ICHECK(f_clone != nullptr) << "PySpaceGenerator's Clone method not implemented!"; return f_clone(); } diff --git a/src/s_tir/meta_schedule/task_scheduler/task_scheduler.cc b/src/s_tir/meta_schedule/task_scheduler/task_scheduler.cc index f06c5a36a7a9..906f695a1c1a 100644 --- a/src/s_tir/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/s_tir/meta_schedule/task_scheduler/task_scheduler.cc @@ -37,11 +37,11 @@ TaskRecord::TaskRecord(TuneContext ctx, double task_weight) { n->task_weight = task_weight; n->flop = 1.0; auto _ = Profiler::TimedScope("InitializeTask"); - CHECK(ctx->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; - CHECK(ctx->space_generator.defined()) - << "ValueError: Require `context.space_generator`, but it is not defined"; - CHECK(ctx->search_strategy.defined()) - << "ValueError: Require `context.search_strategy`, but it is not defined"; + TVM_FFI_CHECK(ctx->mod.defined(), ValueError) << "Require `context.mod`, but it is not defined"; + TVM_FFI_CHECK(ctx->space_generator.defined(), ValueError) + << "Require `context.space_generator`, but it is not defined"; + TVM_FFI_CHECK(ctx->search_strategy.defined(), ValueError) + << "Require `context.search_strategy`, but it is not defined"; TVM_PY_LOG(INFO, ctx->logger) << "\n" << ctx->mod; ctx->Initialize(); n->flop = std::max(1.0, s_tir::EstimateTIRFlops(ctx->mod.value())); @@ -65,7 +65,7 @@ void SendToRunner(TaskRecordNode* self, const Runner& runner) { ffi::Array candidates = self->measure_candidates.value(); ffi::Array builder_results = self->builder_results.value(); Target target = self->ctx->target.value(); - ICHECK_EQ(candidates.size(), builder_results.size()); + TVM_FFI_ICHECK_EQ(candidates.size(), builder_results.size()); int n = candidates.size(); int n_build_errors = 0; ffi::Array inputs; @@ -105,8 +105,8 @@ void SendToRunner(TaskRecordNode* self, const Runner& runner) { } void TaskCleanUp(TaskRecordNode* self, int task_id, const ffi::Array& results) { - ICHECK_EQ(self->builder_results.value().size(), results.size()); - ICHECK_EQ(self->runner_futures.value().size(), results.size()); + TVM_FFI_ICHECK_EQ(self->builder_results.value().size(), results.size()); + TVM_FFI_ICHECK_EQ(self->runner_futures.value().size(), results.size()); int n = results.size(); std::string name = self->ctx->task_name.value(); const ffi::Function& logger = self->ctx->logger; @@ -156,8 +156,9 @@ void TaskSchedulerNode::Tune(ffi::Array ctxs, ffi::Array ffi::Array measure_callbacks, ffi::Optional database, ffi::Optional cost_model) { - CHECK_EQ(ctxs.size(), task_weights.size()) << "ValueError: `task_weights` must have the same " - "length as `ctxs`"; + TVM_FFI_CHECK_EQ(ctxs.size(), task_weights.size(), ValueError) + << "`task_weights` must have the same " + "length as `ctxs`"; int n_tasks = this->remaining_tasks_ = ctxs.size(); this->measure_callbacks_ = measure_callbacks; this->database_ = database; @@ -191,8 +192,8 @@ void TaskSchedulerNode::Tune(ffi::Array ctxs, ffi::Array TVM_PY_LOG(INFO, this->logger) << "TaskScheduler picks Task #" << task_id << ": " << tasks_[task_id]->ctx->task_name; TaskRecordNode* task = tasks_[task_id].get(); - ICHECK(!task->is_terminated); - ICHECK(!task->runner_futures.defined()); + TVM_FFI_ICHECK(!task->is_terminated); + TVM_FFI_ICHECK(!task->runner_futures.defined()); if (static_cast(task->latency_ms.size()) >= max_trials_per_task) { TerminateTask(task_id); continue; @@ -223,7 +224,7 @@ void TaskSchedulerNode::Tune(ffi::Array ctxs, ffi::Array ffi::Array TaskSchedulerNode::JoinRunningTask(int task_id) { TaskRecordNode* task = this->tasks_[task_id].get(); - ICHECK(task->runner_futures.defined()); + TVM_FFI_ICHECK(task->runner_futures.defined()); ffi::Array results; { auto _ = Profiler::TimedScope("JoinRunnerFutures"); @@ -233,12 +234,12 @@ ffi::Array TaskSchedulerNode::JoinRunningTask(int task_id) { results.push_back(future->Result()); } } - ICHECK(task->measure_candidates.defined()); + TVM_FFI_ICHECK(task->measure_candidates.defined()); task->ctx->search_strategy.value()->NotifyRunnerResults(task->measure_candidates.value(), results); - ICHECK(task->builder_results.defined()); - ICHECK_EQ(results.size(), task->measure_candidates.value().size()); - ICHECK_EQ(results.size(), task->builder_results.value().size()); + TVM_FFI_ICHECK(task->builder_results.defined()); + TVM_FFI_ICHECK_EQ(results.size(), task->measure_candidates.value().size()); + TVM_FFI_ICHECK_EQ(results.size(), task->builder_results.value().size()); for (const MeasureCallback& callback : this->measure_callbacks_) { callback->Apply(ffi::GetRef(this), task_id, task->measure_candidates.value(), task->builder_results.value(), results); @@ -264,7 +265,7 @@ void TaskSchedulerNode::TouchTask(int task_id) { void TaskSchedulerNode::TerminateTask(int task_id) { TaskRecordNode* task = this->tasks_[task_id].get(); - ICHECK(!task->is_terminated); + TVM_FFI_ICHECK(!task->is_terminated); task->is_terminated = true; --this->remaining_tasks_; TVM_PY_LOG_CLEAR_SCREEN(this->logger); @@ -335,7 +336,7 @@ void TaskSchedulerNode::PrintTuningStatistics() { TaskScheduler TaskScheduler::PyTaskScheduler( ffi::Function logger, PyTaskSchedulerNode::FNextTaskId f_next_task_id, PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, PyTaskSchedulerNode::FTune f_tune) { - CHECK(f_next_task_id != nullptr) << "ValueError: next_task_id is not defined"; + TVM_FFI_CHECK(f_next_task_id != nullptr, ValueError) << "next_task_id is not defined"; ObjectPtr n = ffi::make_object(); n->logger = logger; n->f_next_task_id = f_next_task_id; @@ -345,7 +346,8 @@ TaskScheduler TaskScheduler::PyTaskScheduler( } int PyTaskSchedulerNode::NextTaskId() { - CHECK(f_next_task_id != nullptr) << "PyTaskScheduler's NextTaskId method not implemented!"; + TVM_FFI_ICHECK(f_next_task_id != nullptr) + << "PyTaskScheduler's NextTaskId method not implemented!"; return f_next_task_id(); } diff --git a/src/s_tir/meta_schedule/trace_apply.cc b/src/s_tir/meta_schedule/trace_apply.cc index aef2d209b0ce..ff373bdff468 100644 --- a/src/s_tir/meta_schedule/trace_apply.cc +++ b/src/s_tir/meta_schedule/trace_apply.cc @@ -168,7 +168,7 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { auto block = Downcast(inputs[0]); auto block_sref = sch->GetSRef(block); if (!CanReverseComputeInline(sch->state(), block_sref)) { - ICHECK(CanComputeInline(sch->state(), block_sref)); + TVM_FFI_ICHECK(CanComputeInline(sch->state(), block_sref)); sch->ComputeInline(block); continue; } @@ -178,7 +178,7 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { auto block_sref = sch->GetSRef(block); auto state = sch->state(); if (!CanComputeInline(state, block_sref)) { - ICHECK(IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false))) + TVM_FFI_ICHECK(IsOutputBlock(state, block_sref, GetScopeRoot(state, block_sref, false))) << "If a spatial block cannot be inlined, it should be the output block"; if (CanReverseComputeInline(sch->state(), block_sref)) { sch->ReverseComputeInline(block); @@ -197,7 +197,7 @@ std::vector ApplyAnchorTrace(Schedule sch, Trace anchor_trace) { // violates the assumption made by TranslateAddOutputRVs: old_outputs.size() == // new_outputs.size(). We workaround this problem by assuming that the prefix of the "new" // outputs matches with the "old" outputs, and truncating the new outputs accordingly. - ICHECK(inst->outputs.size() <= outputs.size()); + TVM_FFI_ICHECK(inst->outputs.size() <= outputs.size()); TranslateAddOutputRVs( inst->outputs, ffi::Array(outputs.begin(), outputs.begin() + inst->outputs.size()), &rv_map); @@ -238,7 +238,7 @@ void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm InlinePostBlocks(sch, anchor_trace, target); auto unscheduled_blocks = ApplyAnchorTrace(sch, anchor_trace); - ICHECK(unscheduled_blocks.size() <= 1) + TVM_FFI_ICHECK(unscheduled_blocks.size() <= 1) << "All blocks should have been scheduled or only one (fused) spatial block can remain " "unscheduled at this point."; @@ -257,8 +257,8 @@ void ScheduleUsingAnchorTrace(Schedule sch, const Trace& anchor_trace, const tvm sch->Parallel(sch->Fuse(sch->GetLoops(last_block))); } else if (IsGPUTarget(target->kind->name)) { auto max_threads_per_block = target->GetAttr("max_threads_per_block"); - ICHECK(max_threads_per_block.defined()) - << "ValueError: missing attribute `max_threads_per_block` in the target"; + TVM_FFI_CHECK(max_threads_per_block.defined(), ValueError) + << "missing attribute `max_threads_per_block` in the target"; auto auto_bind_rule = ScheduleRule::AutoBind(/*max_threadblocks=*/256, diff --git a/src/s_tir/meta_schedule/tune_context.cc b/src/s_tir/meta_schedule/tune_context.cc index b6b61663eeab..9a935e931c4f 100644 --- a/src/s_tir/meta_schedule/tune_context.cc +++ b/src/s_tir/meta_schedule/tune_context.cc @@ -31,7 +31,8 @@ TuneContext::TuneContext(ffi::Optional mod, ffi::Optional targ ffi::Optional search_strategy, ffi::Optional task_name, int num_threads, TRandState rand_state, ffi::Function logger) { - CHECK(rand_state == -1 || rand_state >= 0) << "ValueError: Invalid random state: " << rand_state; + TVM_FFI_CHECK(rand_state == -1 || rand_state >= 0, ValueError) + << "Invalid random state: " << rand_state; ObjectPtr n = ffi::make_object(); n->mod = mod; n->target = target; diff --git a/src/s_tir/meta_schedule/utils.h b/src/s_tir/meta_schedule/utils.h index 7cad637d4c3b..6b2dd3c96f47 100644 --- a/src/s_tir/meta_schedule/utils.h +++ b/src/s_tir/meta_schedule/utils.h @@ -90,7 +90,7 @@ class PyLogMessage { : filename_(filename), lineno_(lineno), logger_(logger), logging_level_(logging_level) {} TVM_NO_INLINE ~PyLogMessage() { - ICHECK(logging_level_ != Level::CLEAR) + TVM_FFI_ICHECK(logging_level_ != Level::CLEAR) << "Cannot use CLEAR as logging level in TVM_PY_LOG, please use TVM_PY_LOG_CLEAR_SCREEN."; if (this->logger_ != nullptr) { logger_(static_cast(logging_level_), std::string(filename_), lineno_, stream_.str()); @@ -142,7 +142,7 @@ inline bool using_ipython() { inline void print_interactive_table(const ffi::String& data) { const auto f_print_interactive_table = tvm::ffi::Function::GetGlobal("s_tir.meta_schedule.print_interactive_table"); - ICHECK(f_print_interactive_table.has_value()) + TVM_FFI_ICHECK(f_print_interactive_table.has_value()) << "Cannot find print_interactive_table function in registry."; (*f_print_interactive_table)(data); } @@ -395,10 +395,10 @@ inline int GetTargetNumCores(const Target& target) { int num_cores = target->GetAttr("num-cores").value_or(-1).IntValue(); if (num_cores == -1) { static const auto f_cpu_count = tvm::ffi::Function::GetGlobal("s_tir.meta_schedule.cpu_count"); - ICHECK(f_cpu_count.has_value()) - << "ValueError: Cannot find the packed function \"s_tir.meta_schedule._cpu_count\""; + TVM_FFI_CHECK(f_cpu_count.has_value(), ValueError) + << "Cannot find the packed function \"s_tir.meta_schedule._cpu_count\""; num_cores = (*f_cpu_count)(false).cast(); - LOG(FATAL) + TVM_FFI_THROW(InternalError) << "Target does not have attribute \"num-cores\", physical core number must be " "defined! For example, on the local machine, the target must be \"llvm -num-cores " << num_cores << "\""; @@ -413,7 +413,7 @@ inline int GetTargetNumCores(const Target& target) { */ inline double GetRunMsMedian(const RunnerResult& runner_result) { ffi::Array run_secs = runner_result->run_secs.value(); - ICHECK(!run_secs.empty()); + TVM_FFI_ICHECK(!run_secs.empty()); std::vector v; v.reserve(run_secs.size()); std::transform(run_secs.begin(), run_secs.end(), std::back_inserter(v), @@ -434,7 +434,7 @@ inline double GetRunMsMedian(const RunnerResult& runner_result) { */ inline ffi::Array AsFloatArray(const ObjectRef& obj) { const ffi::ArrayObj* arr = obj.as(); - ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey(); + TVM_FFI_CHECK(arr, TypeError) << "Expect an array, but gets: " << obj->GetTypeKey(); ffi::Array results; results.reserve(arr->size()); for (Any val : *arr) { @@ -444,7 +444,8 @@ inline ffi::Array AsFloatArray(const ObjectRef& obj) { } else if (auto opt_float_imm = val.try_cast()) { return *std::move(opt_float_imm); } else { - LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: " << val.GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Expect an array of float or int, but gets: " + << val.GetTypeKey(); TVM_FFI_UNREACHABLE(); } }(); @@ -461,7 +462,7 @@ inline ffi::Array AsFloatArray(const ObjectRef& obj) { */ inline ffi::Array AsIntArray(const ObjectRef& obj) { const ffi::ArrayObj* arr = obj.as(); - ICHECK(arr) << "TypeError: Expect an array, but gets: " << obj->GetTypeKey(); + TVM_FFI_CHECK(arr, TypeError) << "Expect an array, but gets: " << obj->GetTypeKey(); ffi::Array results; results.reserve(arr->size()); for (Any val : *arr) { @@ -469,7 +470,7 @@ inline ffi::Array AsIntArray(const ObjectRef& obj) { if (auto opt_int_imm = val.try_cast()) { return (*opt_int_imm)->value; } else { - LOG(FATAL) << "TypeError: Expect an array of integers, but gets: " << val.GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Expect an array of integers, but gets: " << val.GetTypeKey(); TVM_FFI_UNREACHABLE(); } }(); @@ -555,15 +556,15 @@ inline ScheduleRule GetDefaultAutoInline(const std::string& target_name) { } else if (IsGPUTarget(target_name)) { rules = ScheduleRule::DefaultCUDA(); } else { - LOG(FATAL) << "ValueError: Unsupported target: " << target_name; + TVM_FFI_THROW(ValueError) << "Unsupported target: " << target_name; } for (const ScheduleRule& rule : rules) { if (rule->GetTypeKey() == "s_tir.meta_schedule.AutoInline") { return rule; } } - LOG(FATAL) << "ValueError: AutoInline rule is not found in the default rules for target: " - << target_name; + TVM_FFI_THROW(ValueError) << "AutoInline rule is not found in the default rules for target: " + << target_name; throw; } @@ -623,7 +624,7 @@ class SBlockCollector : public tir::StmtVisitor { /*! \brief Override the Stmt visiting behaviour */ void VisitStmt_(const tir::SBlockNode* block) override { tir::StmtVisitor::VisitStmt_(block); - CHECK(block_names_.count(block->name_hint) == 0) + TVM_FFI_ICHECK(block_names_.count(block->name_hint) == 0) << "Duplicated block name " << block->name_hint << " in function " << func_name_ << " not supported!"; block_names_.insert(block->name_hint); diff --git a/src/s_tir/schedule/analysis/analysis.cc b/src/s_tir/schedule/analysis/analysis.cc index a00958e5246e..9a7763660ef6 100644 --- a/src/s_tir/schedule/analysis/analysis.cc +++ b/src/s_tir/schedule/analysis/analysis.cc @@ -48,9 +48,10 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl } } } - LOG(FATAL) << "IndexError: Could not get the corresponding function in the schedule state of the " - "statement:\n" - << ffi::GetRef(root_block); + TVM_FFI_THROW(IndexError) + << "Could not get the corresponding function in the schedule state of the " + "statement:\n" + << ffi::GetRef(root_block); throw; } @@ -127,7 +128,7 @@ ScopeBlockLoopInfo GetScopeBlockLoopInfo(const SBlock& scope_block) { result.realizes.push_back(ffi::GetRef(realize)); const ffi::Array& iter_vars = realize->block->iter_vars; const ffi::Array& iter_values = realize->iter_values; - ICHECK_EQ(iter_vars.size(), iter_values.size()); + TVM_FFI_ICHECK_EQ(iter_vars.size(), iter_values.size()); int n = realize->iter_values.size(); for (int i = 0; i < n; ++i) { const IterVar& iter_var = iter_vars[i]; @@ -162,7 +163,8 @@ void CheckSRefHigherOrEqual(const StmtSRef& sref_a, const StmtSRef& sref_b) { return; } } - CHECK(false) << "Expect StmtSRef " << sref_a << "to be higher than or equal to " << sref_b; + TVM_FFI_ICHECK(false) << "Expect StmtSRef " << sref_a << "to be higher than or equal to " + << sref_b; } /*! @@ -428,7 +430,8 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt violate_block_(std::move(violate_block)), local_complete_block_code_(local_complete_block_code), local_reduction_block_code_(local_reduction_block_code) { - ICHECK(subtree_root_->IsInstance() || subtree_root_->IsInstance()); + TVM_FFI_ICHECK(subtree_root_->IsInstance() || + subtree_root_->IsInstance()); } ffi::String FastErrorString() const final { return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, " @@ -693,7 +696,7 @@ ffi::Map GetBindings(const SBlockRealize& realize) { const SBlockNode* block = realize->block.get(); const ffi::Array& all_lhs = block->iter_vars; const ffi::Array& all_rhs = realize->iter_values; - ICHECK_EQ(all_lhs.size(), all_rhs.size()); + TVM_FFI_ICHECK_EQ(all_lhs.size(), all_rhs.size()); ffi::Map result; for (int i = 0, n = all_lhs.size(); i < n; ++i) { const IterVar& lhs = all_lhs[i]; @@ -707,12 +710,12 @@ bool GetVarsTouchedByBlockIters(const SBlockRealize& block_realize, std::unordered_set* data_par_vars, std::unordered_set* reduce_vars) { SBlock block = block_realize->block; - ICHECK(block_realize->block.same_as(block)) - << "ValueError: The input `block_realize` is required to be the exact BlockRealize of the " + TVM_FFI_CHECK(block_realize->block.same_as(block), ValueError) + << "The input `block_realize` is required to be the exact BlockRealize of the " "input block"; bool has_block_vars_of_other_types = false; - ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size()); + TVM_FFI_ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size()); int n = static_cast(block->iter_vars.size()); for (int i = 0; i < n; ++i) { const IterVar& iter_var = block->iter_vars[i]; @@ -802,7 +805,7 @@ ffi::Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_ const auto* block = static_cast(parent_sref->stmt); return Collector::Collect(block->body); } - ICHECK(false) << "Unreachable"; + TVM_FFI_ICHECK(false) << "Unreachable"; throw; } @@ -872,8 +875,8 @@ SBlockRealize GetSBlockRealize(const ScheduleState& self, const StmtSRef& block_ } else { BlockRealizeFinder finder(block); finder(ffi::GetRef(block_sref->parent->stmt)); - ICHECK(finder.result != nullptr) - << "InternalError: Cannot find the BlockRealize of block " << ffi::GetRef(block); + TVM_FFI_CHECK(finder.result != nullptr, InternalError) + << "Cannot find the BlockRealize of block " << ffi::GetRef(block); return ffi::GetRef(finder.result); } } @@ -895,7 +898,7 @@ IterVarType GetLoopIterType(const StmtSRef& loop_sref) { if (const auto* realize = obj.as()) { const SBlockNode* block = realize->block.get(); // Number of block vars and their bindings - ICHECK_EQ(realize->iter_values.size(), block->iter_vars.size()); + TVM_FFI_ICHECK_EQ(realize->iter_values.size(), block->iter_vars.size()); size_t n = realize->iter_values.size(); for (size_t i = 0; i < n; ++i) { const IterVar& iter_var = block->iter_vars[i]; @@ -933,7 +936,8 @@ IterVarType GetLoopIterType(const StmtSRef& loop_sref) { } StmtSRef GetSRefLowestCommonAncestor(const ffi::Array& srefs) { - CHECK(!srefs.empty()) << "ValueError: The input array is required to have at least one sref"; + TVM_FFI_CHECK(!srefs.empty(), ValueError) + << "The input array is required to have at least one sref"; std::unordered_map sref_visited_cnt; for (const StmtSRef& sref : srefs) { @@ -948,7 +952,7 @@ StmtSRef GetSRefLowestCommonAncestor(const ffi::Array& srefs) { while (p != nullptr && sref_visited_cnt[p] != n_sref) { p = p->parent; } - ICHECK(p != nullptr); + TVM_FFI_ICHECK(p != nullptr); return ffi::GetRef(p); } @@ -989,7 +993,7 @@ std::pair, std::vector> CollectComputeLocation( ffi::Array loop_srefs = GetLoops(consumers[0]); size_t lca_pos = std::find(loop_srefs.begin(), loop_srefs.end(), loop_boundary) - loop_srefs.begin(); - ICHECK_LT(lca_pos, loop_srefs.size()); + TVM_FFI_ICHECK_LT(lca_pos, loop_srefs.size()); size_t n_candidate = lca_pos + 1; // Step 5. Find the position of the deepest data-parallel loop among the candidate loops. This @@ -1075,7 +1079,7 @@ ffi::Array GetOutputBlocks(const ScheduleState& self, const SBlockNode void VisitStmt_(const SBlockNode* block) override { auto it = self_->stmt2ref.find(block); - ICHECK(it != self_->stmt2ref.end()); + TVM_FFI_ICHECK(it != self_->stmt2ref.end()); auto block_sref = it->second; if (block_sref->parent != nullptr) { StmtSRef scope_root_sref = @@ -1722,7 +1726,7 @@ TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer, const PrimFunc& desc_func) { TensorIntrinDescInfo info; const auto* desc_scope_realize = desc_func->body.as(); - ICHECK(desc_scope_realize); + TVM_FFI_ICHECK(desc_scope_realize); { auto f_visit = [&](const ObjectRef& obj) -> bool { // Extract the block @@ -1742,7 +1746,7 @@ TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer, }; tir::PostOrderVisit(desc_scope_realize->block->body, f_visit); std::reverse(info.desc_loops.begin(), info.desc_loops.end()); - ICHECK(info.desc_block); + TVM_FFI_ICHECK(info.desc_block); } return info; } @@ -1792,8 +1796,8 @@ ffi::Optional GetTensorizeLoopMapping(const s_tir::ScheduleState& const std::vector iter_types_block = GetSBlockVarTypes(block_sref); const std::vector iter_types_desc = GetSBlockVarTypes(desc_block->block.get()); - ICHECK(desc_loops.size() == static_cast(n_desc_vars)); - ICHECK(block_loops.size() == iter_types_block.size()); + TVM_FFI_ICHECK(desc_loops.size() == static_cast(n_desc_vars)); + TVM_FFI_ICHECK(block_loops.size() == iter_types_block.size()); // We assume that the orders of iter_vars in the target and the desc block are consistent. // Based on that assumption, the following logic supports arbitrary permutations of a loop order, @@ -1976,7 +1980,7 @@ class AutoTensorizeMappingProposer { } // Step 2: Compute the buffer mask - ICHECK_EQ(rhs_buffer_index.size(), lhs_buffer_index.size()); + TVM_FFI_ICHECK_EQ(rhs_buffer_index.size(), lhs_buffer_index.size()); int num_buffers = rhs_buffer_index.size(); std::unordered_map> rhs_buffer_masks, lhs_buffer_masks; // helper function to initialize or update the buffer mask @@ -1994,13 +1998,14 @@ class AutoTensorizeMappingProposer { if (const VarNode* var_node = rhs_index.as()) { update_mask(var_node, &rhs_buffer_masks, rhs_buffer_index.at(rhs_buffer)); } else { - LOG(FATAL) << "ValueError: Buffer index " << rhs_index - << " other that variables in tensor intrinsics is not supported."; + TVM_FFI_THROW(ValueError) + << "Buffer index " << rhs_index + << " other that variables in tensor intrinsics is not supported."; } } auto lhs_buffer_it = extractor_->rhs_buffer_map_.find(rhs_buffer); - ICHECK(lhs_buffer_it != extractor_->rhs_buffer_map_.end()); + TVM_FFI_ICHECK(lhs_buffer_it != extractor_->rhs_buffer_map_.end()); const Buffer& lhs_buffer = lhs_buffer_it->second; for (const PrimExpr& index : extractor_->lhs_buffer_indices_map_.at(lhs_buffer)) { PreOrderVisit(index, [&](const ObjectRef& obj) -> bool { diff --git a/src/s_tir/schedule/analysis/layout.cc b/src/s_tir/schedule/analysis/layout.cc index 6c1feb10f706..2b353399377f 100644 --- a/src/s_tir/schedule/analysis/layout.cc +++ b/src/s_tir/schedule/analysis/layout.cc @@ -31,7 +31,7 @@ using namespace tvm::tir; */ ffi::Array GetStrides(const Buffer& buffer) { if (!buffer->strides.empty()) { - ICHECK_EQ(buffer->strides.size(), buffer->shape.size()); + TVM_FFI_ICHECK_EQ(buffer->strides.size(), buffer->shape.size()); return buffer->strides; } int ndim = buffer->shape.size(); @@ -86,7 +86,7 @@ class SplitExprCollector { if (iter_sum_exprs.empty()) { return {}; } - ICHECK_EQ(iter_sum_exprs.size(), 1); + TVM_FFI_ICHECK_EQ(iter_sum_exprs.size(), 1); if (iter_sum_exprs[0]->args.size() == 0) { return {}; } @@ -111,7 +111,7 @@ class SplitExprCollector { } else if (auto iter_sum_expr = expr->source->source.as()) { Visit(iter_sum_expr.value()); } else { - ICHECK(false) << "Unexpected type: " << expr->source->source->GetTypeKey(); + TVM_FFI_ICHECK(false) << "Unexpected type: " << expr->source->source->GetTypeKey(); } } @@ -181,7 +181,7 @@ ffi::Optional SuggestIndexMap(const Buffer& buffer, const ffi::Array

shape, // analyzer // ](ffi::Array indices) -> ffi::Array { - ICHECK_EQ(indices.size(), shape.size()); + TVM_FFI_ICHECK_EQ(indices.size(), shape.size()); for (int i = 0, n = indices.size(); i < n; ++i) { analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i])); } @@ -209,7 +209,7 @@ ffi::Optional SuggestIndexMap(const Buffer& buffer, const ffi::Array

shape, analyzer](ffi::Array indices) -> ffi::Array { - ICHECK_EQ(indices.size(), split_exprs.size()); + TVM_FFI_ICHECK_EQ(indices.size(), split_exprs.size()); // Step 6.1: Reorder the indices according to `inverse_order`. This is the inverse of Step 5.3. // After the inverse permutation, indices[i] corresponds to split_exprs[i] ffi::Array inv_permuted_indices; diff --git a/src/s_tir/schedule/analysis/reducer.cc b/src/s_tir/schedule/analysis/reducer.cc index 547dc7d8b89b..7559a5bfb9d7 100644 --- a/src/s_tir/schedule/analysis/reducer.cc +++ b/src/s_tir/schedule/analysis/reducer.cc @@ -263,7 +263,7 @@ class PatternMatcher : public ExprVisitor { this->match_success_ = true; this->filled_map_.clear(); - ICHECK_EQ(pattern_.size(), exprs_to_match.size()); + TVM_FFI_ICHECK_EQ(pattern_.size(), exprs_to_match.size()); int n_buffers = pattern_.size(); for (int i = 0; i < n_buffers; ++i) { this->expr_to_match_ = exprs_to_match[i]; @@ -273,8 +273,8 @@ class PatternMatcher : public ExprVisitor { PrimExpr Eval(const Var& var) { auto it = filled_map_.find(var.operator->()); - ICHECK(it != filled_map_.end()) << "Unknown pattern variable"; - ICHECK(match_success_) << "Match failed"; + TVM_FFI_ICHECK(it != filled_map_.end()) << "Unknown pattern variable"; + TVM_FFI_ICHECK(match_success_) << "Match failed"; return it->second; } @@ -335,10 +335,10 @@ void ErrorRFactorCrossThreadReductionNotApplicable(const ffi::Optionalmod, std::move(block), violated_cond); } else { - LOG(FATAL) << "ValueError: Cross-thread reduction cannot be applied to the block " - << block->name_hint << " because the block violates the condition #" << violated_cond - << ".\n" - << kRFactorCrossThreadReductionApplicableBlockDef; + TVM_FFI_THROW(ValueError) << "Cross-thread reduction cannot be applied to the block " + << block->name_hint << " because the block violates the condition #" + << violated_cond << ".\n" + << kRFactorCrossThreadReductionApplicableBlockDef; } } @@ -425,7 +425,7 @@ void ExtractReductionUpdates(const ffi::Optional& self, SBlock bl } } for (int i = 0; i < n_buffers; ++i) { - ICHECK((*updates)[i].defined()); + TVM_FFI_ICHECK((*updates)[i].defined()); } } @@ -464,7 +464,7 @@ std::pair, ffi::Array> GetInitValuesAndUpdates const auto* let = block->body.as(); ExtractReductionUpdates(self, block, let, n_buffers, &updates, &buf2index); } - ICHECK_EQ(updates.size(), n_buffers); + TVM_FFI_ICHECK_EQ(updates.size(), n_buffers); // Step 3. Set the init values according to the buffer order in `updates`, with the help of the // mapping `buf2index`. @@ -477,7 +477,7 @@ std::pair, ffi::Array> GetInitValuesAndUpdates // - Check buffers do not duplicate const ffi::Array& expected_shape = updates[0]->buffer->shape; const ffi::Array& expected_indices = updates[0]->indices; - ICHECK_EQ(expected_shape.size(), expected_indices.size()); + TVM_FFI_ICHECK_EQ(expected_shape.size(), expected_indices.size()); int n_dim = expected_indices.size(); arith::Analyzer ana; for (int i = 0; i < n_buffers; ++i) { @@ -503,12 +503,12 @@ std::pair, ffi::Array> GetInitValuesAndUpdates ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/10); } int idx = it->second; - ICHECK(updates[idx]->buffer.same_as(inits[i]->buffer)); - ICHECK(!init_values[idx].defined()); + TVM_FFI_ICHECK(updates[idx]->buffer.same_as(inits[i]->buffer)); + TVM_FFI_ICHECK(!init_values[idx].defined()); init_values.Set(idx, inits[i]->value); } for (int i = 0; i < n_buffers; ++i) { - ICHECK(init_values[i].defined()); + TVM_FFI_ICHECK(init_values[i].defined()); } return std::make_pair(init_values, updates); @@ -574,9 +574,10 @@ bool ReductionIterNotIndexOutputBuffer(const SBlock& block) { bool write_is_covered_by_match_buffer = match_buffer_sources.count(store->buffer.get()) && buffer_written.count(match_buffer_sources.find(store->buffer.get())->second); - ICHECK(buffer_written.count(store->buffer.get()) || write_is_covered_by_match_buffer || - buffer_allocated.count(store->buffer.get())) - << "ValueError: The buffer \"" << store->buffer + TVM_FFI_CHECK(buffer_written.count(store->buffer.get()) || write_is_covered_by_match_buffer || + buffer_allocated.count(store->buffer.get()), + ValueError) + << "The buffer \"" << store->buffer << "\" is written in the block but is not in the block's signature nor is it covered by " "a match_buffer"; for (const PrimExpr& index : store->indices) { @@ -630,8 +631,9 @@ std::tuple, ffi::Array> GetReducerAn if (self.defined()) { throw NoMatchedReducerError(self.value()->mod, identities, combiners); } else { - LOG(FATAL) << "ValueError: No matched reducer for the identity and the combiner of the " - "reduction block. So rfactor and cross-thread reduction cannot be applied."; + TVM_FFI_THROW(ValueError) + << "No matched reducer for the identity and the combiner of the " + "reduction block. So rfactor and cross-thread reduction cannot be applied."; } } return std::make_tuple(std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs)); @@ -644,7 +646,7 @@ bool MatchReducer(const CommReducer& reducer, const ffi::Array& identi const ffi::Array& buf_loads, ffi::Array* lhs, ffi::Array* rhs) { ExprDeepEqual equal; - ICHECK_EQ(identities.size(), combined_values.size()); + TVM_FFI_ICHECK_EQ(identities.size(), combined_values.size()); int n_buffers = identities.size(); for (int i = 0; i < n_buffers; ++i) { if (!equal(reducer->identity_element[i], identities[i])) { diff --git a/src/s_tir/schedule/analysis/verify.cc b/src/s_tir/schedule/analysis/verify.cc index 91b10d7d5f95..b31f624c3bca 100644 --- a/src/s_tir/schedule/analysis/verify.cc +++ b/src/s_tir/schedule/analysis/verify.cc @@ -32,40 +32,41 @@ class SRefTreeVerifier : public StmtVisitor { void Verify() { VisitPrimFuncs(self_->mod, [this](const PrimFuncNode* func) { this->VisitStmt(func->body); }); - ICHECK_EQ(n_sref_visited_, static_cast(self_->stmt2ref.size())); + TVM_FFI_ICHECK_EQ(n_sref_visited_, static_cast(self_->stmt2ref.size())); for (const auto& kv : self_->block_info) { const StmtSRef& sref = kv.first; - ICHECK(sref->stmt != nullptr) - << "InternalError: An expired sref is found in the block_scope mapping"; + TVM_FFI_CHECK(sref->stmt != nullptr, InternalError) + << "An expired sref is found in the block_scope mapping"; auto it = self_->stmt2ref.find(sref->stmt); - ICHECK(it != self_->stmt2ref.end()) - << "InternalError: The sref points to a statement that does not exist in stmt2ref"; + TVM_FFI_CHECK(it != self_->stmt2ref.end(), InternalError) + << "The sref points to a statement that does not exist in stmt2ref"; const StmtSRef& sref2 = it->second; - ICHECK(sref.same_as(sref2)) - << "InternalError: The sref points to a statement whose corresponding sref in stmt2ref " + TVM_FFI_CHECK(sref.same_as(sref2), InternalError) + << "The sref points to a statement whose corresponding sref in stmt2ref " "is not the same object as itself"; } - ICHECK_EQ(n_block_sref_visited_, static_cast(self_->block_info.size())); + TVM_FFI_ICHECK_EQ(n_block_sref_visited_, static_cast(self_->block_info.size())); } void VisitStmt_(const SBlockNode* block) final { if (init_block_depth_) { - ICHECK(!self_->stmt2ref.count(block)) << "InternalError: A block inside init block has its " - "corresponding sref, which is not allowed"; + TVM_FFI_CHECK(!self_->stmt2ref.count(block), InternalError) + << "A block inside init block has its " + "corresponding sref, which is not allowed"; StmtVisitor::VisitStmt_(block); return; } - ICHECK(self_->stmt2ref.count(block)) - << "InternalError: A BlockNode should appear in sref map, but it didn't\n" + TVM_FFI_CHECK(self_->stmt2ref.count(block), InternalError) + << "A BlockNode should appear in sref map, but it didn't\n" << ffi::GetRef(block); ++n_sref_visited_; ++n_block_sref_visited_; const StmtSRef& sref = self_->stmt2ref.at(block); - ICHECK(self_->block_info.count(sref)) - << "InternalError: Cannot find scope information of the BlockNode:\n" + TVM_FFI_CHECK(self_->block_info.count(sref), InternalError) + << "Cannot find scope information of the BlockNode:\n" << ffi::GetRef(block); - ICHECK(sref->parent == ancestors_.back()) - << "InternalError: Parent information mismatch for BlockNode:\n" + TVM_FFI_CHECK(sref->parent == ancestors_.back(), InternalError) + << "Parent information mismatch for BlockNode:\n" << ffi::GetRef(block) << "\nIts parent is supposed to be:\n" << ffi::GetRef(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n" @@ -83,19 +84,20 @@ class SRefTreeVerifier : public StmtVisitor { void VisitStmt_(const ForNode* loop) final { if (init_block_depth_) { - ICHECK(!self_->stmt2ref.count(loop)) << "InternalError: A loop inside init block has its " - "corresponding sref, which is not allowed"; + TVM_FFI_CHECK(!self_->stmt2ref.count(loop), InternalError) + << "A loop inside init block has its " + "corresponding sref, which is not allowed"; StmtVisitor::VisitStmt_(loop); return; } - ICHECK(self_->stmt2ref.count(loop)) - << "InternalError: A ForNode should appear in sref map, but it didn't\n" + TVM_FFI_CHECK(self_->stmt2ref.count(loop), InternalError) + << "A ForNode should appear in sref map, but it didn't\n" << ffi::GetRef(loop); ++n_sref_visited_; const StmtSRef& sref = self_->stmt2ref.at(loop); ffi::Optional stmt = std::nullopt; - ICHECK(sref->parent == ancestors_.back()) - << "InternalError: Parent information mismatch for ForNode:\n" + TVM_FFI_CHECK(sref->parent == ancestors_.back(), InternalError) + << "Parent information mismatch for ForNode:\n" << ffi::GetRef(loop) << "\nIts parent is supposed to be:\n" << ffi::GetRef(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n" @@ -118,15 +120,15 @@ class SRefTreeVerifier : public StmtVisitor { StmtSRef sref{nullptr}; if (const auto* realize = child.as()) { const auto* block = realize->block.get(); - ICHECK(self_->stmt2ref.count(block)); + TVM_FFI_ICHECK(self_->stmt2ref.count(block)); sref = self_->stmt2ref.at(block); } else if (child->IsInstance()) { - ICHECK(self_->stmt2ref.count(child.get())); + TVM_FFI_ICHECK(self_->stmt2ref.count(child.get())); sref = self_->stmt2ref.at(child.get()); } else { continue; } - ICHECK_EQ(sref->seq_index, i) << "InternalError: A StmtSRef has incorrect seq_index"; + TVM_FFI_CHECK_EQ(sref->seq_index, i, InternalError) << "A StmtSRef has incorrect seq_index"; } StmtVisitor::VisitStmt_(seq_stmt); } @@ -195,7 +197,7 @@ void VerifyCachedFlags(const ScheduleState& self) { os << "- SBlockInfo not found:"; for (const StmtSRef& block_sref : block_info_not_found) { const auto* block = block_sref->StmtAs(); - ICHECK(block); + TVM_FFI_ICHECK(block); os << " " << block->name_hint; } os << std::endl; @@ -207,7 +209,7 @@ void VerifyCachedFlags(const ScheduleState& self) { bool expected = std::get<1>(record); bool actual = std::get<2>(record); const auto* block = block_sref->StmtAs(); - ICHECK(block); + TVM_FFI_ICHECK(block); os << " (" << block->name_hint << ", expected=" << expected << ", actual=" << actual << ")"; } os << std::endl; @@ -219,7 +221,7 @@ void VerifyCachedFlags(const ScheduleState& self) { bool expected = std::get<1>(record); bool actual = std::get<2>(record); const auto* block = block_sref->StmtAs(); - ICHECK(block); + TVM_FFI_ICHECK(block); os << " (" << block->name_hint << ", expected=" << expected << ", actual=" << actual << ")"; } os << std::endl; @@ -231,14 +233,14 @@ void VerifyCachedFlags(const ScheduleState& self) { bool expected = std::get<1>(record); bool actual = std::get<2>(record); const auto* block = block_sref->StmtAs(); - ICHECK(block); + TVM_FFI_ICHECK(block); os << " (" << block->name_hint << ", expected=" << expected << ", actual=" << actual << ")"; } os << std::endl; } - LOG(FATAL) << "Schedule verification failed. The IR is:\n" - << self->mod << "\nThe errors are:\n" - << os.str(); + TVM_FFI_THROW(InternalError) << "Schedule verification failed. The IR is:\n" + << self->mod << "\nThe errors are:\n" + << os.str(); throw; } diff --git a/src/s_tir/schedule/concrete_schedule.cc b/src/s_tir/schedule/concrete_schedule.cc index 266c7ff46425..9d5068c61b62 100644 --- a/src/s_tir/schedule/concrete_schedule.cc +++ b/src/s_tir/schedule/concrete_schedule.cc @@ -320,9 +320,10 @@ SBlockRV ConcreteScheduleNode::GetSBlock(const ffi::String& name, } else if (func_working_on_.has_value()) { gv = this->func_working_on_.value(); } else { - LOG(FATAL) << "ValueError: `get_sblock` does not know which function to be working on. Please " - "specify the function name explicitly, or call `work_on` to specify the function " - "before using `get_sblock`."; + TVM_FFI_THROW(ValueError) + << "`get_sblock` does not know which function to be working on. Please " + "specify the function name explicitly, or call `work_on` to specify the function " + "before using `get_sblock`."; } ffi::Array blocks = s_tir::GetSBlocks(this->state_, name, gv); if (blocks.size() != 1) { @@ -379,7 +380,7 @@ ffi::Array ConcreteScheduleNode::GetOutputBlocks(const SBlockRV& scope /******** Schedule: Transform loops ********/ LoopRV ConcreteScheduleNode::Merge(const ffi::Array& loop_rvs) { - CHECK(loop_rvs.size() > 1) << "ValueError: 'merge' requires at least 2 loop(s)"; + TVM_FFI_CHECK(loop_rvs.size() > 1, ValueError) << "'merge' requires at least 2 loop(s)"; ffi::Array loop_srefs = this->GetSRefs(loop_rvs); StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); @@ -390,7 +391,7 @@ LoopRV ConcreteScheduleNode::Merge(const ffi::Array& loop_rvs) { } LoopRV ConcreteScheduleNode::Fuse(const ffi::Array& loop_rvs, bool preserve_unit_iters) { - CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)"; + TVM_FFI_CHECK(!loop_rvs.empty(), ValueError) << "'fuse' requires at least 1 loop(s)"; ffi::Array loop_srefs = this->GetSRefs(loop_rvs); StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); @@ -944,8 +945,8 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { } if (const auto* expr = ann_val.as()) { - ICHECK(!expr->IsInstance()) - << "TypeError: ffi::String is expected, but gets StringImm"; + TVM_FFI_CHECK(!expr->IsInstance(), TypeError) + << "ffi::String is expected, but gets StringImm"; auto res_expr = this->Get(ffi::GetRef(expr)); // prefer to return int/float literals for annotations if (auto opt_intimm = res_expr.as()) { @@ -974,13 +975,13 @@ Any ConcreteScheduleNode::CheckAndGetAnnotationValue(const ffi::Any& ann_val) { } else if (auto opt_str = key.try_cast()) { result.Set(opt_str.value(), value); } else { - LOG(FATAL) << "TypeError: annotation dict key expect to be ffi::String or StringImm"; + TVM_FFI_THROW(TypeError) << "annotation dict key expect to be ffi::String or StringImm"; } } return result; } - LOG(FATAL) - << "TypeError: Only strings, integers, floats, ExprRVs and Arrays are supported for now, but " + TVM_FFI_THROW(TypeError) + << "Only strings, integers, floats, ExprRVs and Arrays are supported for now, but " << "gets: " << ann_val.GetTypeKey(); TVM_FFI_UNREACHABLE(); } diff --git a/src/s_tir/schedule/concrete_schedule.h b/src/s_tir/schedule/concrete_schedule.h index e475eb6aefc0..ba058637dc97 100644 --- a/src/s_tir/schedule/concrete_schedule.h +++ b/src/s_tir/schedule/concrete_schedule.h @@ -262,7 +262,7 @@ inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { PrimExpr transformed = Substitute(expr_rv, [this](const Var& var) -> ffi::Optional { auto it = this->symbol_table_.find(var); if (it == this->symbol_table_.end()) { - LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << var; + TVM_FFI_THROW(IndexError) << "Cannot find corresponding ExprRV: " << var; } const ObjectRef& obj = (*it).second; const auto* int_imm = TVM_TYPE_AS(obj, IntImmNode); @@ -287,16 +287,16 @@ inline bool ConcreteScheduleNode::HasBlock(const SBlockRV& block_rv) const { inline StmtSRef ConcreteScheduleNode::GetSRef(const SBlockRV& block_rv) const { auto it = this->symbol_table_.find(block_rv); if (it == this->symbol_table_.end()) { - LOG(FATAL) << "IndexError: Cannot find corresponding SBlockRV: " << block_rv; + TVM_FFI_THROW(IndexError) << "Cannot find corresponding SBlockRV: " << block_rv; } const ObjectRef& obj = (*it).second; const auto* sref = obj.as(); if (sref == nullptr) { - LOG(FATAL) << "ValueError: SBlockRV's corresponding type is invalid: " - << (obj.defined() ? obj->GetTypeKey() : "None"); + TVM_FFI_THROW(ValueError) << "SBlockRV's corresponding type is invalid: " + << (obj.defined() ? obj->GetTypeKey() : "None"); } if (sref->stmt == nullptr) { - LOG(FATAL) << "ValueError: The block no longer exists in the IRModule"; + TVM_FFI_THROW(ValueError) << "The block no longer exists in the IRModule"; } return ffi::GetRef(sref); } @@ -306,7 +306,7 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { static StmtSRef root_mark = StmtSRef::RootMark(); auto it = this->symbol_table_.find(loop_rv); if (it == this->symbol_table_.end()) { - LOG(FATAL) << "IndexError: Cannot find corresponding LoopRV: " << loop_rv; + TVM_FFI_THROW(IndexError) << "Cannot find corresponding LoopRV: " << loop_rv; } const ObjectRef& obj = (*it).second; if (obj.same_as(inline_mark)) { @@ -317,11 +317,11 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { } const auto* sref = obj.as(); if (sref == nullptr) { - LOG(FATAL) << "ValueError: LoopRV's corresponding type is invalid: " - << (obj.defined() ? obj->GetTypeKey() : "None"); + TVM_FFI_THROW(ValueError) << "LoopRV's corresponding type is invalid: " + << (obj.defined() ? obj->GetTypeKey() : "None"); } if (sref->stmt == nullptr) { - LOG(FATAL) << "ValueError: The loop no longer exists in the IRModule"; + TVM_FFI_THROW(ValueError) << "The loop no longer exists in the IRModule"; } return ffi::GetRef(sref); } @@ -391,7 +391,7 @@ inline void ConcreteScheduleNode::RemoveFromSymbolTable(const ObjectRef& obj) { if (it != this->symbol_table_.end()) { this->symbol_table_.erase(obj); } else { - LOG(FATAL) << "IndexError: Cannot find the object in the symbol table: " << obj; + TVM_FFI_THROW(IndexError) << "Cannot find the object in the symbol table: " << obj; throw; } } diff --git a/src/s_tir/schedule/instruction.cc b/src/s_tir/schedule/instruction.cc index 7feb4b25ae3c..cb9f357de31b 100644 --- a/src/s_tir/schedule/instruction.cc +++ b/src/s_tir/schedule/instruction.cc @@ -48,7 +48,8 @@ using InstructionKindRegistry = AttrRegistryGet(name); - ICHECK(reg != nullptr) << "AttributeError: Instruction kind " << name << " is not registered"; + TVM_FFI_CHECK(reg != nullptr, AttributeError) + << "Instruction kind " << name << " is not registered"; return reg->inst_kind_; } @@ -65,7 +66,7 @@ InstructionKindRegEntry& InstructionKindRegEntry::RegisterOrGet(const ffi::Strin TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { const auto* self = obj.as(); - ICHECK_NOTNULL(self); + TVM_FFI_ICHECK_NOTNULL(self); ffi::Array inputs; inputs.reserve(self->inputs.size()); for (const Any& obj : self->inputs) { @@ -92,7 +93,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } else if (obj.as()) { inputs.push_back(obj); } else { - LOG(FATAL) << "TypeError: Stringifying is not supported for type: " << obj.GetTypeKey(); + TVM_FFI_THROW(TypeError) + << "Stringifying is not supported for type: " << obj.GetTypeKey(); throw; } } diff --git a/src/s_tir/schedule/instruction_traits.h b/src/s_tir/schedule/instruction_traits.h index 395aab2d5ede..47968f4bc34e 100644 --- a/src/s_tir/schedule/instruction_traits.h +++ b/src/s_tir/schedule/instruction_traits.h @@ -320,7 +320,7 @@ ffi::Array UnpackedInstTraits::ApplyToSchedule(const Schedule& sch TTraits::template _SetDecision<1 + kNumInputs + kNumAttrs>(packed_args, decision); ffi::Function pf([](const ffi::PackedArgs& args, ffi::Any* rv) -> void { constexpr size_t kNumArgs = details::NumArgs; - ICHECK_EQ(args.size(), kNumArgs); + TVM_FFI_ICHECK_EQ(args.size(), kNumArgs); ffi::details::unpack_call(std::make_index_sequence{}, nullptr, TTraits::UnpackedApplyToSchedule, args.data(), args.size(), rv); @@ -352,7 +352,7 @@ ffi::String UnpackedInstTraits::AsPython(const ffi::Array& inputs, TTraits::template _SetDecision<1 + kNumInputs + kNumAttrs>(packed_args, decision); ffi::Function pf([](const ffi::PackedArgs& args, ffi::Any* rv) -> void { constexpr size_t kNumArgs = details::NumArgs; - ICHECK_EQ(args.size(), kNumArgs); + TVM_FFI_ICHECK_EQ(args.size(), kNumArgs); ffi::details::unpack_call(std::make_index_sequence{}, nullptr, TTraits::UnpackedAsPython, args.data(), args.size(), rv); }); @@ -366,8 +366,8 @@ template TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetInputs(AnyView* packed_args, const ffi::Array& inputs) { constexpr size_t kNumInputs = TTraits::kNumInputs; - ICHECK_EQ(kNumInputs, inputs.size()) - << "ValueError: Incorrect kNumInputs for instruction: " << TTraits::kName; + TVM_FFI_CHECK_EQ(kNumInputs, inputs.size(), ValueError) + << "Incorrect kNumInputs for instruction: " << TTraits::kName; for (size_t i = 0; i < kNumInputs; ++i) { packed_args[i + index_offset] = inputs[i]; } @@ -378,8 +378,8 @@ template TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetAttrs(AnyView* packed_args, const ffi::Array& attrs) { constexpr size_t kNumAttrs = TTraits::kNumAttrs; - ICHECK_EQ(kNumAttrs, attrs.size()) - << "ValueError: Incorrect kNumAttrs for instruction: " << TTraits::kName; + TVM_FFI_CHECK_EQ(kNumAttrs, attrs.size(), ValueError) + << "Incorrect kNumAttrs for instruction: " << TTraits::kName; for (size_t i = 0; i < kNumAttrs; ++i) { packed_args[i + index_offset] = attrs[i]; } @@ -394,7 +394,7 @@ TVM_ALWAYS_INLINE void UnpackedInstTraits::_SetDecision(AnyView* packed if (kNumDecisions == 1) { packed_args[index_offset] = decision; } else { - ICHECK(decision == nullptr); + TVM_FFI_ICHECK(decision == nullptr); } } @@ -463,8 +463,8 @@ inline void PythonAPICall::AsPythonString(const Any& obj, std::ostream& os) { } os << '}'; } else { - LOG(FATAL) << "ValueError: Cannot translate type '" << obj.GetTypeKey() - << "' to python. Its value is: " << obj; + TVM_FFI_THROW(ValueError) << "Cannot translate type '" << obj.GetTypeKey() + << "' to python. Its value is: " << obj; throw; } } @@ -522,7 +522,7 @@ void PythonAPICall::Decision(Any decision) { } void PythonAPICall::SingleOutput(ffi::Array unit_array) { - ICHECK_EQ(unit_array.size(), 1); + TVM_FFI_ICHECK_EQ(unit_array.size(), 1); this->output_ = unit_array[0]; } diff --git a/src/s_tir/schedule/ir_comparator.cc b/src/s_tir/schedule/ir_comparator.cc index c2dc59acfec4..6f69583b88f0 100644 --- a/src/s_tir/schedule/ir_comparator.cc +++ b/src/s_tir/schedule/ir_comparator.cc @@ -35,7 +35,7 @@ class TensorIntrinMismatchError : public ScheduleError { lhs_stmt_(std::move(lhs_stmt)), rhs_stmt_(std::move(rhs_stmt)), error_messages_(std::move(error_messages)) { - ICHECK(lhs_stmt_->IsInstance() || lhs_stmt_->IsInstance()); + TVM_FFI_ICHECK(lhs_stmt_->IsInstance() || lhs_stmt_->IsInstance()); } ffi::String FastErrorString() const final { @@ -464,7 +464,7 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf auto it = buffer_indices_.find(lhs->buffer); if (it == buffer_indices_.end()) { // Update base indices for the buffer, this can only happen if it is visiting the scope block. - ICHECK(is_scope_block); + TVM_FFI_ICHECK(is_scope_block); std::vector indices_base; indices_base.reserve(lhs->region.size()); for (int i = 0; i < offset; i++) { @@ -565,9 +565,9 @@ bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { return false; } auto it = buffer_indices_.find(lhs->buffer); - ICHECK(it != buffer_indices_.end()); + TVM_FFI_ICHECK(it != buffer_indices_.end()); const std::vector& indices_base = (*it).second; - ICHECK_EQ(indices_base.size(), rhs->indices.size() + offset); + TVM_FFI_ICHECK_EQ(indices_base.size(), rhs->indices.size() + offset); for (size_t i = 0; i < rhs->indices.size(); i++) { PrimExpr normalized_lhs_index = lhs->indices[i + offset] - indices_base[i + offset]; if (!analyzer_.CanProveEqual(normalized_lhs_index, rhs->indices[i])) { diff --git a/src/s_tir/schedule/primitive/annotate.cc b/src/s_tir/schedule/primitive/annotate.cc index 3584d703371c..0a8d985e2632 100644 --- a/src/s_tir/schedule/primitive/annotate.cc +++ b/src/s_tir/schedule/primitive/annotate.cc @@ -31,7 +31,7 @@ void Annotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann_k } else if (const auto* block = sref->StmtAs()) { annotations = &block->annotations; } else { - LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Unknown type of sref: " << sref->stmt->GetTypeKey(); } // Check if the annotation already exists if (annotations->find(ann_key) != annotations->end()) { @@ -51,7 +51,7 @@ void Annotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann_k SBlock p(n); self->Replace(sref, p, {{ffi::GetRef(block), p}}); } else { - LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Unknown type of sref: " << sref->stmt->GetTypeKey(); throw; } } @@ -64,11 +64,11 @@ void Unannotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann } else if (const auto* block = sref->StmtAs()) { annotations = &block->annotations; } else { - LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Unknown type of sref: " << sref->stmt->GetTypeKey(); } // Remove the annotation - ICHECK(annotations->find(ann_key) != annotations->end()) - << "IndexError: Cannot find annotation key: " << ann_key; + TVM_FFI_CHECK(annotations->find(ann_key) != annotations->end(), IndexError) + << "Cannot find annotation key: " << ann_key; ffi::Map new_ann(*annotations); new_ann.erase(ann_key); // Create the new stmt @@ -82,7 +82,7 @@ void Unannotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann SBlock p(n); self->Replace(sref, p, {{ffi::GetRef(block), p}}); } else { - LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Unknown type of sref: " << sref->stmt->GetTypeKey(); throw; } } @@ -104,8 +104,8 @@ struct AnnotateTraits : public UnpackedInstTraits { if (auto loop = block_or_loop_rv.as()) { return sch->Annotate(loop.value(), ann_key, ann_val); } - LOG(FATAL) << "TypeError: Expected SBlock or Loop, but gets: " - << block_or_loop_rv->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Expected SBlock or Loop, but gets: " + << block_or_loop_rv->GetTypeKey(); throw; } @@ -139,8 +139,8 @@ struct UnannotateTraits : public UnpackedInstTraits { if (auto loop = block_or_loop_rv.as()) { return sch->Unannotate(loop.value(), ann_key); } - LOG(FATAL) << "TypeError: Expected SBlock or Loop, but gets: " - << block_or_loop_rv->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Expected SBlock or Loop, but gets: " + << block_or_loop_rv->GetTypeKey(); throw; } diff --git a/src/s_tir/schedule/primitive/annotate_buffer_access.cc b/src/s_tir/schedule/primitive/annotate_buffer_access.cc index 5fc00e3c364d..2e77570e4eb5 100644 --- a/src/s_tir/schedule/primitive/annotate_buffer_access.cc +++ b/src/s_tir/schedule/primitive/annotate_buffer_access.cc @@ -36,8 +36,9 @@ class AnnotateRegionRewriter : public StmtExprMutator { ffi::Array regions = buffer_index_type_ == BufferIndexType::kWrite ? block->writes : block->reads; - ICHECK_GE(buffer_index_, 0) << "Buffer index must be non-negative"; - ICHECK_LT(buffer_index_, static_cast(regions.size())) << "Buffer index out of range"; + TVM_FFI_ICHECK_GE(buffer_index_, 0) << "Buffer index must be non-negative"; + TVM_FFI_ICHECK_LT(buffer_index_, static_cast(regions.size())) + << "Buffer index out of range"; regions.Set(buffer_index_, new_region_); ObjectPtr n = CopyOnWrite(block.get()); @@ -93,7 +94,7 @@ void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int bu block_iter_vars.push_back(iter_var->var); } ffi::Array new_indices = index_map->MapIndices(block_iter_vars, &analyzer); - ICHECK_EQ(new_indices.size() % 2, 0) << "The size of new_indices should be even."; + TVM_FFI_ICHECK_EQ(new_indices.size() % 2, 0) << "The size of new_indices should be even."; ffi::Array new_ranges; for (size_t i = 0; i < new_indices.size(); i += 2) { // (begin, end) represents a region diff --git a/src/s_tir/schedule/primitive/blockize_tensorize.cc b/src/s_tir/schedule/primitive/blockize_tensorize.cc index f7cfda547910..f9631b96d484 100644 --- a/src/s_tir/schedule/primitive/blockize_tensorize.cc +++ b/src/s_tir/schedule/primitive/blockize_tensorize.cc @@ -217,7 +217,7 @@ ffi::Map DeriveBlockBinding( using arith::IterMapExprNode; using arith::NormalizeIterMapToExpr; ffi::Map block_var_subst; - ICHECK_EQ(iter_vars.size() + 1, division.size()); + TVM_FFI_ICHECK_EQ(iter_vars.size() + 1, division.size()); arith::Analyzer ana; for (int i = 0, n = iter_vars.size(); i < n; ++i) { const IterVar& iter_var = iter_vars[i]; @@ -235,8 +235,8 @@ ffi::Map DeriveBlockBinding( IterVar outer_iter; if (reuse_outer) { outer_iter = outer_iter_vars->operator[](i); - ICHECK(ana.CanProveEqual(outer_iter->dom->extent, outer_mark->extent)); - ICHECK( + TVM_FFI_ICHECK(ana.CanProveEqual(outer_iter->dom->extent, outer_mark->extent)); + TVM_FFI_ICHECK( ana.CanProveEqual(outer_bindings->operator[](i), NormalizeIterMapToExpr(outer_binding))); } else { outer_iter = IterVar(/*dom=*/RangeFromExtent(outer_mark->extent), @@ -320,7 +320,7 @@ Stmt GenerateOuterInit(const Stmt& block_init, const SBlockRealize& inner_realiz // 2) It is used in the original init block ffi::Array iter_vars; ffi::Array iter_values; - ICHECK_EQ(inner_block->iter_vars.size(), inner_realize->iter_values.size()); + TVM_FFI_ICHECK_EQ(inner_block->iter_vars.size(), inner_realize->iter_values.size()); int n = inner_block->iter_vars.size(); iter_vars.reserve(n); iter_values.reserve(n); @@ -427,7 +427,7 @@ ffi::Array EvalSetRegions(const ffi::Array& regions, for (const BufferRegion& buffer_region : regions) { const Buffer& buffer = buffer_region->buffer; ffi::Array relaxed = arith::EvalSet(buffer_region->region, dom_map); - ICHECK_EQ(relaxed.size(), buffer->shape.size()); + TVM_FFI_ICHECK_EQ(relaxed.size(), buffer->shape.size()); int ndim = buffer->shape.size(); ffi::Array new_region; new_region.reserve(ndim); @@ -631,7 +631,7 @@ SBlockRealize BlockizeBlocks(const ScheduleState& self, const ffi::Array(); - ICHECK(seq) << "Target blocks must not be nested with each other!"; + TVM_FFI_ICHECK(seq) << "Target blocks must not be nested with each other!"; int idx_start = -1; int last_found_idx = -1; size_t cur_idx = 0; @@ -689,7 +689,7 @@ class BlockizeRewriter : public StmtMutator { idx_start = cur_idx; new_seq.push_back(blockized_); } else { - ICHECK_EQ(last_found_idx, cur_idx - 1) << "Target blocks must be consecutive!"; + TVM_FFI_ICHECK_EQ(last_found_idx, cur_idx - 1) << "Target blocks must be consecutive!"; } last_found_idx = cur_idx; } else { @@ -757,8 +757,8 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int ffi::Map block_sref_reuse; block_realize = BlockizeImpl(self, sref, &block_sref_reuse, &analyzer, preserve_unit_iters); } else { - LOG(FATAL) << "TypeError: Tensorize only support For or SBlock, but gets: " - << ffi::GetRef(sref->stmt); + TVM_FFI_THROW(TypeError) << "Tensorize only support For or SBlock, but gets: " + << ffi::GetRef(sref->stmt); throw; } @@ -776,7 +776,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int }; f_update_max_dtype_bits_from_region(block_realize->block->reads); f_update_max_dtype_bits_from_region(block_realize->block->writes); - ICHECK(index_dtype_bits > 0); + TVM_FFI_ICHECK(index_dtype_bits > 0); intrin_impl = IndexDataTypeNormalizer(DataType::Int(index_dtype_bits)).Rewrite(intrin_impl); // Step 2: Structural pattern matching TensorizeComparator comparator(self->mod, /*assert_mode=*/true); @@ -786,7 +786,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int // 2) Buffer mapping from intrin impl buffers to buffers in the current AST. // 3) Mapping impl buffers to their accessed regions. std::unordered_map impl2desc; - ICHECK_EQ(intrin_desc->params.size(), intrin_impl->params.size()); + TVM_FFI_ICHECK_EQ(intrin_desc->params.size(), intrin_impl->params.size()); for (int i = 0, n = intrin_desc->params.size(); i < n; ++i) { const Buffer& desc = intrin_desc->buffer_map[intrin_desc->params[i]]; const Buffer& impl = intrin_impl->buffer_map[intrin_impl->params[i]]; @@ -796,7 +796,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int for (const auto& pair : impl2desc) { const Buffer& impl = pair.first; const Buffer& desc = pair.second; - ICHECK(comparator.rhs_buffer_map_.count(desc)); + TVM_FFI_ICHECK(comparator.rhs_buffer_map_.count(desc)); impl2cur[impl] = comparator.rhs_buffer_map_[desc]; } std::unordered_map, ObjectPtrHash, ObjectPtrEqual> impl2region; @@ -817,7 +817,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int const ffi::Array& old_region = impl2region.at(impl); const std::vector& indices_base = comparator.buffer_indices_.at(cur); int offset = static_cast(indices_base.size()) - static_cast(old_region.size()); - ICHECK(offset >= 0); + TVM_FFI_ICHECK(offset >= 0); ffi::Array new_region; new_region.reserve(cur->shape.size()); for (int i = 0; i < offset; i++) { @@ -876,7 +876,7 @@ struct BlockizeTraits : public UnpackedInstTraits { } else if (auto blocks = target.as>()) { return sch->Blockize(blocks.value(), preserve_unit_iters.operator bool()); } - LOG(FATAL) << "TypeError: expect Loop or list of SBlocks, but gets:" << target->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "expect Loop or list of SBlocks, but gets:" << target->GetTypeKey(); } static ffi::String UnpackedAsPython(ffi::Array outputs, ObjectRef target, @@ -908,8 +908,8 @@ struct TensorizeTraits : public UnpackedInstTraits { } else if (auto loop = block_or_loop_rv.as()) { sch->Tensorize(loop.value(), intrin, preserve_unit_iters.operator bool()); } else { - LOG(FATAL) << "TypeError: Expected SBlock or Loop, but gets: " - << block_or_loop_rv->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Expected SBlock or Loop, but gets: " + << block_or_loop_rv->GetTypeKey(); } } diff --git a/src/s_tir/schedule/primitive/cache_index.cc b/src/s_tir/schedule/primitive/cache_index.cc index a9f57fe61d9c..0a5bb9a0d293 100644 --- a/src/s_tir/schedule/primitive/cache_index.cc +++ b/src/s_tir/schedule/primitive/cache_index.cc @@ -62,8 +62,8 @@ DataType DetermineDatatype(const arith::IntSet& range) { if (ana.CanProve(range.min() >= INT32_MIN && range.max() <= INT32_MAX)) { return DataType::Int(32); } else { - ICHECK(ana.CanProve(range.min() >= make_const(DataType::Int(64), INT64_MIN) && - range.max() <= make_const(DataType::Int(64), INT64_MAX))); + TVM_FFI_ICHECK(ana.CanProve(range.min() >= make_const(DataType::Int(64), INT64_MIN) && + range.max() <= make_const(DataType::Int(64), INT64_MAX))); return DataType::Int(64); } } @@ -356,7 +356,7 @@ Stmt InsertIndexStage(const Stmt& stmt, int pos, const Stmt& stage) { if (pos == 0) { return SeqStmt::Flatten>({stage, stmt}); } - ICHECK_EQ(pos, 1); + TVM_FFI_ICHECK_EQ(pos, 1); return SeqStmt::Flatten>({stmt, stage}); } @@ -451,7 +451,7 @@ ffi::Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, // Step 0. Checking index, getting the target buffer and the parent scope IndexInfo info; info.target_sblock = block_sref; - CHECK_GE(cse_thresh, 0) << "cse_thresh should not be negative number"; + TVM_FFI_ICHECK_GE(cse_thresh, 0) << "cse_thresh should not be negative number"; info.cse_thresh = cse_thresh; StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); diff --git a/src/s_tir/schedule/primitive/cache_read_write.cc b/src/s_tir/schedule/primitive/cache_read_write.cc index 1d80bc893965..5e761ca04a86 100644 --- a/src/s_tir/schedule/primitive/cache_read_write.cc +++ b/src/s_tir/schedule/primitive/cache_read_write.cc @@ -33,7 +33,7 @@ class NotSingleWriteBlock : public ScheduleError { public: explicit NotSingleWriteBlock(IRModule mod, Buffer buffer, ffi::Array write_blocks) : mod_(std::move(mod)), buffer_(std::move(buffer)) { - ICHECK_GT(write_blocks.size(), 1); + TVM_FFI_ICHECK_GT(write_blocks.size(), 1); write_blocks_.reserve(write_blocks.size()); for (const StmtSRef& block_sref : write_blocks) { const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); @@ -92,7 +92,7 @@ ffi::Optional GetBufferRegionFromBuffer( ffi::Optional res = std::nullopt; for (const auto& region : buffer_regions) { if (region->buffer.same_as(buffer)) { - ICHECK(!res.defined()); + TVM_FFI_ICHECK(!res.defined()); res = region; } } @@ -474,8 +474,8 @@ Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { if (const auto* seq_stmt = body.as()) { ffi::Array seq = seq_stmt->seq; - ICHECK_LE(pos, seq.size()) << "Cannot insert at position " << pos << " into sequence of length " - << seq.size(); + TVM_FFI_ICHECK_LE(pos, seq.size()) + << "Cannot insert at position " << pos << " into sequence of length " << seq.size(); seq.insert(seq.begin() + pos, stage); body = SeqStmt(seq); } else if (pos == 0) { @@ -483,9 +483,9 @@ Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { } else if (pos == 1) { body = SeqStmt({body, stage}); } else { - LOG(FATAL) << "Cannot insert at position " << pos - << ". When inserting adjacent to non-SeqStmt, " - << "only positions 0 and 1 are valid."; + TVM_FFI_THROW(InternalError) << "Cannot insert at position " << pos + << ". When inserting adjacent to non-SeqStmt, " + << "only positions 0 and 1 are valid."; } body = MergeNest(nest, body); @@ -510,7 +510,7 @@ ffi::Optional GetOnlyWriteBlock(ScheduleState self, const StmtSRef& sc return std::nullopt; } else { const ffi::Array& block_srefs = it->second; - ICHECK(!block_srefs.empty()); + TVM_FFI_ICHECK(!block_srefs.empty()); if (block_srefs.size() > 1) { throw NotSingleWriteBlock(self->mod, buffer, block_srefs); } @@ -534,7 +534,7 @@ bool AllConsumersUnderStmt(ScheduleState self, Buffer buffer, StmtSRef scope_sre std::unordered_set blocks_under_target; for (const StmtSRef& block_sref : GetChildBlocks(self, stmt_sref)) { const auto* block = block_sref->StmtAs(); - ICHECK(block != nullptr); + TVM_FFI_ICHECK(block != nullptr); blocks_under_target.insert(block); } @@ -543,7 +543,7 @@ bool AllConsumersUnderStmt(ScheduleState self, Buffer buffer, StmtSRef scope_sre // target stmt. for (const StmtSRef& block_sref : GetChildBlocks(self, scope_sref)) { const auto* block = block_sref->StmtAs(); - ICHECK(block != nullptr); + TVM_FFI_ICHECK(block != nullptr); if (GetBufferRegionFromBuffer(block->reads, buffer).defined()) { if (blocks_under_target.find(block) == blocks_under_target.end()) { return false; @@ -576,7 +576,7 @@ BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_re /*dom_low_inclusive=*/dom_low_inclusive, /*dom_high_exclusive=*/dom_high_exclusive, /*analyzer=*/&analyzer); - ICHECK_EQ(buffer_region->region.size(), int_sets.size()); + TVM_FFI_ICHECK_EQ(buffer_region->region.size(), int_sets.size()); Region region; region.reserve(int_sets.size()); @@ -853,7 +853,7 @@ class CacheReadRewriter : public StmtExprMutator { bool cache_full_region = true) : scope_sref_(scope_sref), info_(info), cache_full_region_(cache_full_region) { auto update_region = [this](const Region& region, const Region& offset) -> Region { - ICHECK_EQ(region.size(), offset.size()); + TVM_FFI_ICHECK_EQ(region.size(), offset.size()); std::vector ret; for (size_t i = 0; i < region.size(); ++i) { ret.push_back(Range::FromMinExtent(ana_.Simplify(region[i]->min - offset[i]->min), @@ -1110,7 +1110,7 @@ class CacheWriteRewriter : public StmtExprMutator { info_(info), cache_full_region_(cache_full_region) { auto update_region = [this](const Region& region, const Region& offset) -> Region { - ICHECK_EQ(region.size(), offset.size()); + TVM_FFI_ICHECK_EQ(region.size(), offset.size()); std::vector ret; for (size_t i = 0; i < region.size(); ++i) { ret.push_back(Range::FromMinExtent(ana_.Simplify(region[i]->min - offset[i]->min), @@ -1797,7 +1797,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu } // Step 3. Check the only writer block. - ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get()); + TVM_FFI_ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get()); // Step 4. Find the producing region and insert position BufferRegion region = GetBufferRegionFromBuffer(block->writes, write_buffer).value(); @@ -2019,8 +2019,8 @@ StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int re // Step 3. Update cache stage info. ffi::Optional maybe_region = GetBufferRegionFromBuffer(block->reads, read_buffer); - ICHECK(maybe_region.defined()) << read_buffer - << " should appear in the block's read region: " << block->reads; + TVM_FFI_ICHECK(maybe_region.defined()) + << read_buffer << " should appear in the block's read region: " << block->reads; BufferRegion cache_region = maybe_region.value(); if (ffi::Optional _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) { @@ -2089,11 +2089,12 @@ StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int w info.write_buffer = write_buffer; // Step 3. Check the only writer block. - ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get()); + TVM_FFI_ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get()); // Step 4. Find the producing region and insert position ffi::Optional maybe_region = GetBufferRegionFromBuffer(block->writes, write_buffer); - ICHECK(maybe_region.defined()) << write_buffer << " should appear in the block's write region"; + TVM_FFI_ICHECK(maybe_region.defined()) + << write_buffer << " should appear in the block's write region"; StmtSRef parent_sref = ffi::GetRef(block_sref->parent); // Detect insert position CacheLocDetector::Detect(self, block_sref, scope_sref, &info); @@ -2282,8 +2283,8 @@ StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_inde for (loop = block_sref->parent; loop->parent != scope_sref.get();) { const ForNode* outer = loop->parent->StmtAs(); const ForNode* inner = loop->StmtAs(); - ICHECK(outer != nullptr && inner != nullptr); - ICHECK(outer->body.get() == inner); + TVM_FFI_ICHECK(outer != nullptr && inner != nullptr); + TVM_FFI_ICHECK(outer->body.get() == inner); loop = loop->parent; } diff --git a/src/s_tir/schedule/primitive/compute_at.cc b/src/s_tir/schedule/primitive/compute_at.cc index 4f48c1054671..5c476213ddb2 100644 --- a/src/s_tir/schedule/primitive/compute_at.cc +++ b/src/s_tir/schedule/primitive/compute_at.cc @@ -163,7 +163,7 @@ int FindInsertionPoint( } // Step 3. Check if there is at least one index of the position can be inserted into // The valid indices are: (last_producer_position, first_consumer_position] - ICHECK(split.last_producer_position < split.first_consumer_position); + TVM_FFI_ICHECK(split.last_producer_position < split.first_consumer_position); // Step 4. Return the possible insertion point according to index int insert_position; if (index == -1) { @@ -174,9 +174,9 @@ int FindInsertionPoint( index <= split.first_consumer_position) { insert_position = index; } else { - LOG(FATAL) << "Valid index:(-1, -2, [" << split.last_producer_position + 1 << ", " - << split.first_consumer_position << "]), " - << "current index=" << index; + TVM_FFI_THROW(InternalError) << "Valid index:(-1, -2, [" << split.last_producer_position + 1 + << ", " << split.first_consumer_position << "]), " + << "current index=" << index; throw; } return insert_position; @@ -467,7 +467,7 @@ std::pair SolveBlockVarDomain(const arith::IntSet& prov } } } - ICHECK(var.defined()) << "ValueError: BufferRegion pattern match failed: " << provided_min; + TVM_FFI_CHECK(var.defined(), ValueError) << "BufferRegion pattern match failed: " << provided_min; return {var.value(), BlockVarDomainInfo{var_dom, var_bound}}; } @@ -490,8 +490,8 @@ void UpdateBlockVarDomainDimwise( PrimExpr dim_max = max(buffer->shape[i] - 1, 0); if (provided.CanProveSinglePoint(analyzer) && is_const_int(provided.min())) { - ICHECK(required.CanProveSinglePoint(analyzer) && - analyzer->CanProveEqual(provided.min(), required.min())); + TVM_FFI_ICHECK(required.CanProveSinglePoint(analyzer) && + analyzer->CanProveEqual(provided.min(), required.min())); continue; } @@ -500,8 +500,8 @@ void UpdateBlockVarDomainDimwise( if (it != iter_doms->end()) { it->second.Union(dom_info); } else { - ICHECK(analyzer->CanProveEqual(provided.min(), required.min())); - ICHECK(analyzer->CanProveEqual(provided.max(), required.max())); + TVM_FFI_ICHECK(analyzer->CanProveEqual(provided.min(), required.min())); + TVM_FFI_ICHECK(analyzer->CanProveEqual(provided.max(), required.max())); } } } @@ -514,7 +514,7 @@ ffi::Map InverseAffineIterMap(const ffi::Array InverseAffineIterMap(const ffi::Array CalculateBlockVarDomain( } NDIntSet required_region = support::NDIntSetUnion(it->second); NDIntSet provided_region = support::NDIntSetUnion(many_provided_regions); - ICHECK_EQ(provided_region.size(), buffer->shape.size()); - ICHECK_EQ(required_region.size(), buffer->shape.size()); + TVM_FFI_ICHECK_EQ(provided_region.size(), buffer->shape.size()); + TVM_FFI_ICHECK_EQ(required_region.size(), buffer->shape.size()); // Try update iter var domains with current required and provided region pair. if (!UpdateBlockVarDomainAffine(buffer, iter_vars, provided_region, required_region, analyzer, &iter_doms)) { @@ -638,7 +638,7 @@ std::vector CalculateBlockVarDomain( info.bound = arith::Intersect({info.bound, arith::IntSet::FromRange(iter_var->dom)}); } info.Simplify(analyzer); - ICHECK(!info.dom.IsNothing()); + TVM_FFI_ICHECK(!info.dom.IsNothing()); result.push_back(info); } return result; @@ -677,7 +677,7 @@ void CalculateProvidedRequiredRegions( // Step 2. Calculate the region required by dependent blocks under `loop` for (const StmtSRef& required_block_sref : is_compute_at ? consumer_srefs : producer_srefs) { const SBlockNode* required_block = TVM_SREF_TO_SBLOCK(required_block_sref); - ICHECK(block2realize.count(required_block)); + TVM_FFI_ICHECK(block2realize.count(required_block)); RelaxBufferRegions( /*binding=*/GetBindings(ffi::GetRef(block2realize.at(required_block))), /*buffer_regions=*/is_compute_at ? required_block->reads : required_block->writes, diff --git a/src/s_tir/schedule/primitive/compute_inline.cc b/src/s_tir/schedule/primitive/compute_inline.cc index c704fe134aa9..ccc5ea3ccd9c 100644 --- a/src/s_tir/schedule/primitive/compute_inline.cc +++ b/src/s_tir/schedule/primitive/compute_inline.cc @@ -310,7 +310,7 @@ class BaseInliner : public StmtExprMutator { Stmt VisitStmt_(const ForNode* loop) final { if (src_stmt.get() == loop) { loop = tgt_stmt.as(); - ICHECK(loop != nullptr); + TVM_FFI_ICHECK(loop != nullptr); } return StmtExprMutator::VisitStmt_(loop); } @@ -321,7 +321,7 @@ class BaseInliner : public StmtExprMutator { SBlock src_block = ffi::GetRef(block); if (src_block.same_as(src_stmt)) { block = tgt_stmt.as(); - ICHECK(block != nullptr); + TVM_FFI_ICHECK(block != nullptr); } SBlock tgt_block = Downcast(StmtExprMutator::VisitStmt_(block)); bool is_scope_root = src_block.get() == scope_root_sref_->stmt; @@ -546,7 +546,7 @@ class ComputeInliner : public BaseInliner { * \param indices The expressions that the corresponding index variables are replaced to */ void SetIndexSubstitution(const ffi::Array& indices) { - ICHECK_EQ(indices.size(), idx_vars_.size()); + TVM_FFI_ICHECK_EQ(indices.size(), idx_vars_.size()); int n = idx_vars_.size(); for (int i = 0; i < n; ++i) { idx_sub_[idx_vars_[i].get()] = indices[i]; @@ -1255,8 +1255,8 @@ SBlock ReductionEpilogueFuser::CreateFusedReductionBlock( } } - ICHECK_EQ(reduction_data_vars.size(), epilogue_data_vars.size()) - << "ValueError: The number of data parallel iter vars must be the same in the reduction " + TVM_FFI_CHECK_EQ(reduction_data_vars.size(), epilogue_data_vars.size(), ValueError) + << "The number of data parallel iter vars must be the same in the reduction " "and epilogue blocks."; std::unordered_map var_map; diff --git a/src/s_tir/schedule/primitive/decompose_padding.cc b/src/s_tir/schedule/primitive/decompose_padding.cc index e684eb4a069a..e1dbb32f4c60 100644 --- a/src/s_tir/schedule/primitive/decompose_padding.cc +++ b/src/s_tir/schedule/primitive/decompose_padding.cc @@ -173,7 +173,7 @@ class PaddingInfoAnalyzer { if (sum->args.empty()) { region.push_back(Range::FromMinExtent(sum->base, IntImm(sum->base.dtype(), /* value */ 1))); } else { - ICHECK_EQ(sum->args.size(), 1U); + TVM_FFI_ICHECK_EQ(sum->args.size(), 1U); if (!analyzer_->CanProveEqual(sum->args[0]->scale, 1)) { SetError("Strided iteration is not supported"); return {}; @@ -218,7 +218,7 @@ static std::pair CreateConstBlock(const SBlockRealizeNode* }; // create new write region - ICHECK_EQ(block->writes.size(), 1U); + TVM_FFI_ICHECK_EQ(block->writes.size(), 1U); BufferRegion write_region = BufferRegion( block->writes[0]->buffer, block->writes[0]->region.Map([rewrite_expr](const Range& r) { return Range::FromMinExtent(rewrite_expr(r->min), rewrite_expr(r->extent)); @@ -438,7 +438,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, loops.push_back(cur_loop); if (cur_loop.same_as(const_filling_pos)) { - ICHECK(!found_const_filling_pos); + TVM_FFI_ICHECK(!found_const_filling_pos); found_const_filling_pos = true; if (!found_in_bound_filling_pos) { found_in_bound_filling_pos = true; @@ -453,7 +453,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, } } } - ICHECK(in_bound_filling_pos.defined()); + TVM_FFI_ICHECK(in_bound_filling_pos.defined()); if (!found_const_filling_pos) { throw LoopPositionError(self->mod, const_filling_pos, ffi::GetRef(block), "decompose_padding"); diff --git a/src/s_tir/schedule/primitive/for_kind.cc b/src/s_tir/schedule/primitive/for_kind.cc index 90ec40b05712..fe9ae79893f9 100644 --- a/src/s_tir/schedule/primitive/for_kind.cc +++ b/src/s_tir/schedule/primitive/for_kind.cc @@ -88,7 +88,7 @@ void CheckLoopParallelizableInBlock(const ScheduleState& self, ForKind for_kind, // CheckAffineBinding(self, block); // Cond 2. For each block iter whose binding contains `loop_var`, only two cases are allowed. - ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size()); + TVM_FFI_ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size()); int n_iters = static_cast(block->iter_vars.size()); for (int i = 0; i < n_iters; ++i) { const IterVar& iter_var = block->iter_vars[i]; diff --git a/src/s_tir/schedule/primitive/get_block_loop.cc b/src/s_tir/schedule/primitive/get_block_loop.cc index bf13bb7795cd..b2427851f2b5 100644 --- a/src/s_tir/schedule/primitive/get_block_loop.cc +++ b/src/s_tir/schedule/primitive/get_block_loop.cc @@ -32,7 +32,7 @@ ffi::Array GetSBlocks(const ScheduleState& self, const ffi::String& na void VisitStmt_(const SBlockNode* block) override { if (block->name_hint == name_) { auto it = self_->stmt2ref.find(block); - ICHECK(it != self_->stmt2ref.end()); + TVM_FFI_ICHECK(it != self_->stmt2ref.end()); results_.push_back(it->second); } StmtVisitor::VisitStmt_(block); @@ -164,8 +164,8 @@ struct GetChildBlocksTraits : public UnpackedInstTraits { if (auto loop = block_or_loop_rv.as()) { return sch->GetChildBlocks(loop.value()); } - LOG(FATAL) << "TypeError: Expected SBlock or Loop, but gets: " - << block_or_loop_rv->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Expected SBlock or Loop, but gets: " + << block_or_loop_rv->GetTypeKey(); throw; } diff --git a/src/s_tir/schedule/primitive/hide_buffer_access.cc b/src/s_tir/schedule/primitive/hide_buffer_access.cc index 40d116b4f7d2..08482525f9df 100644 --- a/src/s_tir/schedule/primitive/hide_buffer_access.cc +++ b/src/s_tir/schedule/primitive/hide_buffer_access.cc @@ -128,7 +128,7 @@ void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref, } reads = block->reads; } else { - CHECK(false) << "Unrecognized buffer type " << buf_type << ", only support read/write"; + TVM_FFI_ICHECK(false) << "Unrecognized buffer type " << buf_type << ", only support read/write"; } /* Step 1: Replace old block with the new block */ diff --git a/src/s_tir/schedule/primitive/layout_transformation.cc b/src/s_tir/schedule/primitive/layout_transformation.cc index 11a7903851fc..f608a4b0a3ff 100644 --- a/src/s_tir/schedule/primitive/layout_transformation.cc +++ b/src/s_tir/schedule/primitive/layout_transformation.cc @@ -96,7 +96,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { static TransformPlan Plan(SBlock block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, ffi::Optional pad_value, arith::Analyzer* analyzer) { - ICHECK(!pad_value.defined() || pad_value.value()->final_indices.size() == 1) + TVM_FFI_ICHECK(!pad_value.defined() || pad_value.value()->final_indices.size() == 1) << "Internal error: Should be caught by ScheduleError checks prior to this point"; TransformLayoutPlanner visitor(old_buffer); visitor(block); @@ -163,8 +163,8 @@ class TransformLayoutPlanner : private StmtExprVisitor { if (loop_dependency_range) { size_t i = loop_dependency_range.value().first; size_t j = loop_dependency_range.value().second; - ICHECK_LT(i, active_loops_.size()); - ICHECK_LT(j, active_loops_.size()); + TVM_FFI_ICHECK_LT(i, active_loops_.size()); + TVM_FFI_ICHECK_LT(j, active_loops_.size()); write_info.dependent_loopnest = {active_loops_.begin() + i, active_loops_.begin() + j + 1}; } @@ -231,7 +231,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { pad_value(pad_value), new_block_to_old(*new_block_to_old), analyzer(analyzer) { - ICHECK_EQ(info.dependent_loopnest.size(), inverse->final_indices.size()); + TVM_FFI_ICHECK_EQ(info.dependent_loopnest.size(), inverse->final_indices.size()); for (size_t i = 0; i < info.dependent_loopnest.size(); i++) { Var var = info.dependent_loopnest[i]->loop_var; PrimExpr expr = inverse->final_indices[i]; @@ -305,7 +305,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { new_iter_values.push_back(block_realize->iter_values[i]); } - ICHECK_EQ(new_indices.size(), new_buffer->shape.size()); + TVM_FFI_ICHECK_EQ(new_indices.size(), new_buffer->shape.size()); for (size_t i = 0; i < new_indices.size(); i++) { Var var = inverse->initial_indices[i]; Var virtual_var = new_indices[i]; @@ -321,7 +321,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { new_iter_values.push_back(block_realize->iter_values[i]); } - ICHECK_EQ(inverse->final_indices.size(), old_indices.size()); + TVM_FFI_ICHECK_EQ(inverse->final_indices.size(), old_indices.size()); for (size_t i = 0; i < old_indices.size(); i++) { Var var = Downcast(old_indices[i]); PrimExpr expr = Substitute(inverse->final_indices[i], loop_var_to_virtual_var); @@ -342,7 +342,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { const ffi::Array& old_indices = info.store->indices; - ICHECK_EQ(old_indices.size(), op->indices.size()); + TVM_FFI_ICHECK_EQ(old_indices.size(), op->indices.size()); ExprDeepEqual expr_equal; for (size_t i = 0; i < old_indices.size(); i++) { if (!expr_equal(old_indices[i], op->indices[i])) { @@ -409,7 +409,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { return; } - ICHECK(!new_block_to_old.count(after)); + TVM_FFI_ICHECK(!new_block_to_old.count(after)); while (true) { if (auto opt = new_block_to_old.Get(before)) { @@ -469,7 +469,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { ffi::Array iter_values; ffi::Array indices; ffi::Map loop_indices_to_block_indices; - ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); + TVM_FFI_ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); for (size_t i = 0; i < inverse->initial_indices.size(); i++) { const auto& loop_var = inverse->initial_indices[i]; const auto& dim = new_buffer->shape[i]; @@ -524,7 +524,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { return std::nullopt; } - ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); + TVM_FFI_ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) { size_t i = (inverse->initial_indices.size() - 1) - rev_i; Var loop_var = inverse->initial_indices[i]; @@ -563,7 +563,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { ffi::Array iter_vars; ffi::Array iter_values; ffi::Array indices; - ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); + TVM_FFI_ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); for (size_t i = 0; i < inverse->initial_indices.size(); i++) { const auto& loop_var = inverse->initial_indices[i]; const auto& dim = new_buffer->shape[i]; @@ -583,7 +583,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { stmt = SBlockRealize(iter_values, padding_predicate, SBlock(iter_vars, {}, {write_region}, block_name.str(), stmt)); - ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); + TVM_FFI_ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) { size_t i = (inverse->initial_indices.size() - 1) - rev_i; Var loop_var = inverse->initial_indices[i]; @@ -598,7 +598,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { } else if (info.innermost_block_realize) { return info.innermost_block_realize.value(); } else { - LOG(FATAL) << "Write occured outside of any block/loop"; + TVM_FFI_THROW(InternalError) << "Write occured outside of any block/loop"; } }(); return EpiloguePlan{insert_after, stmt}; @@ -659,7 +659,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { struct BindBlockRealize { BindBlockRealize(TransformLayoutPlanner* self, SBlockRealize block_realize) : self_(self) { - ICHECK_EQ(block_realize->iter_values.size(), block_realize->block->iter_vars.size()); + TVM_FFI_ICHECK_EQ(block_realize->iter_values.size(), block_realize->block->iter_vars.size()); for (size_t i = 0; i < block_realize->iter_values.size(); i++) { bound_vars_.emplace_back(self, block_realize->block->iter_vars[i]->var, block_realize->iter_values[i]); @@ -859,7 +859,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { const ffi::Array& infered_access_regions) { auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { if (buffer_region->buffer.same_as(old_buffer_)) { - ICHECK(infered_access_regions.size() == 1); + TVM_FFI_ICHECK(infered_access_regions.size() == 1); return infered_access_regions[0]; } return buffer_region; @@ -903,7 +903,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { return; } - ICHECK(!new_block_to_old_.count(after)); + TVM_FFI_ICHECK(!new_block_to_old_.count(after)); while (true) { if (auto opt = new_block_to_old_.Get(before)) { @@ -980,7 +980,7 @@ class TransformationPaddingTypeError : public ScheduleError { public: TransformationPaddingTypeError(IRModule mod, Buffer buffer, IndexMap pad_value) : mod_(mod), buffer_(buffer), pad_value_(pad_value) { - ICHECK_EQ(pad_value_->final_indices.size(), 1); + TVM_FFI_ICHECK_EQ(pad_value_->final_indices.size(), 1); pad_value_dtype_ = pad_value_->final_indices[0].dtype(); } @@ -1012,7 +1012,7 @@ class TransformationPaddingExpressionError : public ScheduleError { public: static void Check(IRModule mod, Buffer buffer, IndexMap pad_value) { Visitor visitor(buffer); - ICHECK_EQ(pad_value->final_indices.size(), 1) + TVM_FFI_ICHECK_EQ(pad_value->final_indices.size(), 1) << "Internal error: Should be caught by ScheduleError checks prior to this point"; visitor(pad_value->final_indices[0]); if (visitor.illegal_load) { @@ -1102,7 +1102,7 @@ class TransformationIntroducesPaddingError : public ScheduleError { // dtype-mismatch issues later. IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const ffi::Array& args) { const auto& initial_indices_orig = index_map->initial_indices; - ICHECK(args.size() == initial_indices_orig.size()); + TVM_FFI_ICHECK(args.size() == initial_indices_orig.size()); ffi::Array initial_indices; ffi::Map var_map; @@ -1110,7 +1110,7 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const ffi::Arraydtype) + TVM_FFI_ICHECK_EQ(*index_dtype, args[i]->dtype) << "Buffer index " << args[i] << " has dtype " << args[i]->dtype << ", but previous index for the same buffer access used index type " << *index_dtype; } else { @@ -1129,7 +1129,7 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const ffi::Arrayfinal_indices.Map([&](PrimExpr index) { if (auto* ptr = index.as()) { - ICHECK(index_dtype.has_value()); + TVM_FFI_ICHECK(index_dtype.has_value()); return tir::make_const(*index_dtype, ptr->value); } else { return SubstituteWithDataTypeLegalization(index, @@ -1386,7 +1386,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, block_vars.push_back(iter_var->var); block_iter_dom.Set(iter_var->var, iter_var->dom); block_iter_type[iter_var->var.get()] = iter_var->iter_type; - ICHECK(is_zero(iter_var->dom->min)); + TVM_FFI_ICHECK(is_zero(iter_var->dom->min)); block_iter_range_array.push_back(iter_var->dom->extent); } @@ -1470,8 +1470,8 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, // Step 6: Do the actual replacement if (scope_sref->StmtAs()) { - ICHECK(new_loop_vars.empty()) << "Invalid block to loop replacement due to layout transform " - << index_map; + TVM_FFI_ICHECK(new_loop_vars.empty()) + << "Invalid block to loop replacement due to layout transform " << index_map; } self->Replace(scope_sref, body, {{block, new_block}}); } diff --git a/src/s_tir/schedule/primitive/loop_transformation.cc b/src/s_tir/schedule/primitive/loop_transformation.cc index 36f90ce60614..702a32852378 100644 --- a/src/s_tir/schedule/primitive/loop_transformation.cc +++ b/src/s_tir/schedule/primitive/loop_transformation.cc @@ -348,7 +348,7 @@ class LoopsNotAChainError : public ScheduleError { if (kind_ == ProblemKind::kNotUnderAScope) { return {}; } else { - ICHECK(problematic_loop_.defined()); + TVM_FFI_ICHECK(problematic_loop_.defined()); return {problematic_loop_.value()}; } } @@ -842,18 +842,18 @@ StmtSRef Merge(ScheduleState self, const ffi::Array& loop_srefs) { nest_loop_extents = nest_loop_i_extents; } else { if (scope_root_sref_.get() != scope_root_sref.get()) { - LOG(FATAL) << "ScheduleError: Expected the loops to be under the same block scope."; + TVM_FFI_THROW(ScheduleError) << "Expected the loops to be under the same block scope."; throw; } if (nest_loop_i_extents.size() != nest_loop_extents.size()) { - LOG(FATAL) << "ScheduleError: Merge loop's nesting depth must be same, but not."; + TVM_FFI_THROW(ScheduleError) << "Merge loop's nesting depth must be same, but not."; throw; } else { for (size_t j = 0; j < nest_loop_i_extents.size(); j++) { if (!analyzer.CanProveEqual(nest_loop_i_extents[j], nest_loop_extents[j])) { - LOG(FATAL) << "ScheduleError: Merge loop's `extent` must be same, but not." - << " extent=[" << j << "," << nest_loop_extents[j] << "," - << nest_loop_i_extents[j] << "]"; + TVM_FFI_THROW(ScheduleError) << "Merge loop's `extent` must be same, but not." + << " extent=[" << j << "," << nest_loop_extents[j] << "," + << nest_loop_i_extents[j] << "]"; throw; } } @@ -1048,7 +1048,7 @@ std::vector GetLoopsInReorderRange(const ScheduleState& sel const StmtSRefNode* parent_loop_sref = loop_sref->parent; const ForNode* outer = parent_loop_sref->StmtAs(); const ForNode* inner = loop_sref->StmtAs(); - ICHECK(outer != nullptr && inner != nullptr); + TVM_FFI_ICHECK(outer != nullptr && inner != nullptr); if (outer->body.get() != inner) { throw LoopsNotAChainError(self->mod, ffi::GetRef(outer), LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); @@ -1085,7 +1085,7 @@ For ConstructNewLoopChain(const ScheduleState& self, std::vectorStmtAs(); } - ICHECK(copy != nullptr); + TVM_FFI_ICHECK(copy != nullptr); ObjectPtr n = ffi::make_object(*copy); if (new_loop.defined()) { n->body = new_loop; @@ -1160,7 +1160,7 @@ StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) { For new_loop_{nullptr}; }; - CHECK(sref->parent != nullptr) << "ValueError: Cannot add loops on top of the root block"; + TVM_FFI_CHECK(sref->parent != nullptr, ValueError) << "Cannot add loops on top of the root block"; StmtSRef parent_sref = ffi::GetRef(sref->parent); NewLoopCreator creator(sref->stmt); Stmt new_stmt = creator(ffi::GetRef(parent_sref->stmt)); @@ -1370,7 +1370,7 @@ struct AddUnitLoopTraits : public UnpackedInstTraits { } else if (auto loop = rv.as()) { return sch->AddUnitLoop(loop.value()); } else { - LOG(FATAL) << "TypeError: AddUnitLoop expects a loop or block"; + TVM_FFI_THROW(TypeError) << "AddUnitLoop expects a loop or block"; throw; } } diff --git a/src/s_tir/schedule/primitive/pad_einsum.cc b/src/s_tir/schedule/primitive/pad_einsum.cc index bffb3e6da659..5ddeabf2e5e2 100644 --- a/src/s_tir/schedule/primitive/pad_einsum.cc +++ b/src/s_tir/schedule/primitive/pad_einsum.cc @@ -144,7 +144,7 @@ struct BufferPadding { int ndim = buffer_region->region.size(); for (int i = 0; i < ndim; ++i) { PrimExpr pos = buffer_region->region[i]->min; - ICHECK(pos->IsInstance() || pos->IsInstance()); + TVM_FFI_ICHECK(pos->IsInstance() || pos->IsInstance()); if (pos->IsInstance()) { shape.push_back(IntImm(pos->dtype, 1)); } else if (ffi::Optional extent = iter_extents.Get(Downcast(pos))) { @@ -428,7 +428,7 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const ffi::Array< break; } } - ICHECK_NE(pos, -1); + TVM_FFI_ICHECK_NE(pos, -1); // Step 5. For each buffer, if it needs padding, create a new buffer and a new block ffi::Array read_blocks; ffi::Array write_blocks; diff --git a/src/s_tir/schedule/primitive/read_write_at.cc b/src/s_tir/schedule/primitive/read_write_at.cc index 8b55141689be..8231876bb922 100644 --- a/src/s_tir/schedule/primitive/read_write_at.cc +++ b/src/s_tir/schedule/primitive/read_write_at.cc @@ -217,18 +217,18 @@ struct ReadWriteAtImpl { // Step 2. Calculate `insert_pos` and [st, ed) for buffer replacement int insert_pos = -1, st = -1, ed = -1; if (is_read) { - ICHECK(!r_pos.empty()); + TVM_FFI_ICHECK(!r_pos.empty()); // No write after the first read - ICHECK(w_pos.empty() || w_pos.back() < r_pos.front()); + TVM_FFI_ICHECK(w_pos.empty() || w_pos.back() < r_pos.front()); // Can be inserted at [0, r_pos.front()], i.e. before the first read insert_pos = r_pos.front(); // Buffer reads in [insert_pos, +oo) is rewritten st = insert_pos; ed = n_subtrees; } else { - ICHECK(!w_pos.empty()); + TVM_FFI_ICHECK(!w_pos.empty()); // No read after the last write - ICHECK(r_pos.empty() || r_pos.back() <= w_pos.back()); + TVM_FFI_ICHECK(r_pos.empty() || r_pos.back() <= w_pos.back()); // Can be inserted into (w_pos.back(), +oo), i.e. after the last write insert_pos = w_pos.back() + 1; st = 0; diff --git a/src/s_tir/schedule/primitive/reduction.cc b/src/s_tir/schedule/primitive/reduction.cc index 6e54f928c908..087eabbce812 100644 --- a/src/s_tir/schedule/primitive/reduction.cc +++ b/src/s_tir/schedule/primitive/reduction.cc @@ -169,7 +169,7 @@ PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval()); break; } else { - ICHECK(false) << "Unexpected predicate for reduction block"; + TVM_FFI_ICHECK(false) << "Unexpected predicate for reduction block"; } } return new_pred; @@ -534,7 +534,7 @@ class LoopPropertyError : public ScheduleError { return "ScheduleError: A loop who has extent greater than one and is not bound to any " "block iter should not appear under a reduction loop"; } - ICHECK(false) << "Unreachable"; + TVM_FFI_ICHECK(false) << "Unreachable"; throw; } @@ -553,7 +553,7 @@ class LoopPropertyError : public ScheduleError { return "The loop {0} has extent greater than one, and is not bound to any block iter. " "Therefore it shouldn't appear under a reduction loop"; } - ICHECK(false) << "Unreachable"; + TVM_FFI_ICHECK(false) << "Unreachable"; throw; } @@ -873,7 +873,7 @@ class RFactorBlockCreator : public BaseBlockCreator { iter_values_.push_back(old_binding); return; } - ICHECK(old_iter->iter_type == kCommReduce); + TVM_FFI_ICHECK(old_iter->iter_type == kCommReduce); // This block iter is a reduction block iter that touches the rfactor loop. So next we try to // create a new block iter for all loop vars that appear in the old binding. ffi::Array vars_in_old_binding = UndefinedVars(old_binding); @@ -930,7 +930,7 @@ class RFactorBlockCreator : public BaseBlockCreator { Range::FromMinExtent(additional_iter_->var, make_const(additional_iter_->var.dtype(), 1))); ffi::Optional rf_buffer = buffer_map.Get(write_region->buffer); - ICHECK(rf_buffer.defined()); + TVM_FFI_ICHECK(rf_buffer.defined()); write_regions_.push_back(BufferRegion(rf_buffer.value(), Substitute(region, var_map_))); } } @@ -1017,7 +1017,7 @@ class WriteBackBlockCreator : public BaseBlockCreator { ffi::Array& buf_regions = is_read ? read_regions_ : write_regions_; for (const PrimExpr& expr : buf_loads) { const auto* buf_load = expr.as(); - ICHECK(buf_load != nullptr); + TVM_FFI_ICHECK(buf_load != nullptr); ffi::Array region; region.reserve(buf_load->indices.size()); for (const PrimExpr& index : buf_load->indices) { @@ -1159,7 +1159,7 @@ class BlockReplacer : public StmtMutator { Stmt VisitStmt_(const SBlockRealizeNode* block_realize) final { // Due to the visitor's behavior on ForNode, this block-realize must be the reduction block's // block-realize. And we directly return the new `wb_block_realize`. - ICHECK_EQ(block_realize, old_block_realize_.get()); + TVM_FFI_ICHECK_EQ(block_realize, old_block_realize_.get()); return wb_block_realize_; } diff --git a/src/s_tir/schedule/primitive/rolling_buffer.cc b/src/s_tir/schedule/primitive/rolling_buffer.cc index c5c41262f243..ccf85f894b21 100644 --- a/src/s_tir/schedule/primitive/rolling_buffer.cc +++ b/src/s_tir/schedule/primitive/rolling_buffer.cc @@ -267,7 +267,7 @@ class RollingBufferRewriter : public StmtExprMutator { const ffi::Array& infered_access_regions) { auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { if (buffer_region->buffer.same_as(info_->old_buffer)) { - ICHECK(infered_access_regions.size() == 1); + TVM_FFI_ICHECK(infered_access_regions.size() == 1); return infered_access_regions[0]; } return buffer_region; diff --git a/src/s_tir/schedule/primitive/sampling.cc b/src/s_tir/schedule/primitive/sampling.cc index 40bfff408aca..94f2784e13f9 100644 --- a/src/s_tir/schedule/primitive/sampling.cc +++ b/src/s_tir/schedule/primitive/sampling.cc @@ -74,7 +74,7 @@ struct PrimeTable { } } } - ICHECK_EQ(static_cast(primes.size()), static_cast(kNumPrimes)); + TVM_FFI_ICHECK_EQ(static_cast(primes.size()), static_cast(kNumPrimes)); // Calculate the power table for each prime number pow_tab.reserve(primes.size()); for (int32_t prime : primes) { @@ -127,8 +127,8 @@ struct PrimeTable { int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int32_t min_inclusive, int32_t max_exclusive) { - CHECK(min_inclusive < max_exclusive) - << "ValueError: max_exclusive must be greater than min_inclusive."; + TVM_FFI_CHECK(min_inclusive < max_exclusive, ValueError) + << "max_exclusive must be greater than min_inclusive."; if (min_inclusive + 1 == max_exclusive) { return min_inclusive; } @@ -166,21 +166,21 @@ std::vector SampleWithoutReplacement( int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const ffi::Array& candidates, const ffi::Array& probs, ffi::Optional* decision) { - CHECK(candidates.size() == probs.size()) - << "ValueError: number of candidates does not match number of probabilities."; + TVM_FFI_CHECK(candidates.size() == probs.size(), ValueError) + << "number of candidates does not match number of probabilities."; int32_t i = -1; int32_t n = candidates.size(); if (decision->defined()) { i = decision->value()->value; - CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n - << ", but decision is: " << i; + TVM_FFI_CHECK(0 <= i && i < n, ValueError) + << "Wrong decision value, where n = " << n << ", but decision is: " << i; } else { std::vector weights = support::AsVector(probs); std::discrete_distribution dist(weights.begin(), weights.end()); support::LinearCongruentialEngine rand_(rand_state); i = dist(rand_); - ICHECK(0 <= i && i < n) << "ValueError: Unexpected decision generated, where n = " << n - << ", but decision is: " << i; + TVM_FFI_CHECK(0 <= i && i < n, ValueError) + << "Unexpected decision generated, where n = " << n << ", but decision is: " << i; } *decision = Integer(i); // decision is guaranteed not to be nullptr. @@ -189,7 +189,7 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st std::function MakeMultinomialSampler( support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights) { - ICHECK(!weights.empty()); + TVM_FFI_ICHECK(!weights.empty()); std::vector sums; sums.reserve(weights.size()); double sum = 0.0; @@ -203,16 +203,16 @@ std::function MakeMultinomialSampler( double p = dist(rand_); int32_t idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); int32_t n = sums.size(); - CHECK_LE(0, idx); - CHECK_LE(idx, n); + TVM_FFI_ICHECK_LE(0, idx); + TVM_FFI_ICHECK_LE(idx, n); return (idx == n) ? (n - 1) : idx; }; } std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, int32_t extent, int32_t n_splits) { - CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; - CHECK_GE(n_splits, 1) << "ValueError: Cannot tile a loop to 0 or negative splits"; + TVM_FFI_CHECK_GE(extent, 1, ValueError) << "Cannot tile a loop with 0 or negative extent"; + TVM_FFI_CHECK_GE(n_splits, 1, ValueError) << "Cannot tile a loop to 0 or negative splits"; // Handle special case that we can potentially accelerate if (n_splits == 1) { return {extent}; @@ -298,7 +298,7 @@ std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandS if (max_innermost_factor == -1) { return SamplePerfectTile(rand_state, extent, n_splits); } - CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits"; + TVM_FFI_CHECK_GE(n_splits, 2, ValueError) << "Cannot tile a loop into " << n_splits << " splits"; while (true) { std::vector result = SamplePerfectTile(rand_state, extent, n_splits); if (result.back() <= max_innermost_factor) { @@ -322,7 +322,7 @@ std::vector SamplePerfectTile( // Case 2. Use previous decision result = support::AsVector(decision->value()); int n = result.size(); - ICHECK_GE(n, 2); + TVM_FFI_ICHECK_GE(n, 2); int64_t len = *extent; for (int i = n - 1; i > 0; --i) { int64_t& l = result[i]; @@ -339,7 +339,7 @@ std::vector SamplePerfectTile( // Case 3. Use fresh new sampling result result = SamplePerfectTile(rand_state, *extent, n_splits, max_innermost_factor); if (max_innermost_factor != -1) { - ICHECK_LE(result.back(), max_innermost_factor); + TVM_FFI_ICHECK_LE(result.back(), max_innermost_factor); } } *decision = support::AsArray(result); @@ -352,7 +352,7 @@ TVM_DLL std::vector SamplePartitionedTile( if (partition_pos == 0 && innerpart_factor == 1) { return SamplePerfectTile(rand_state, extent, n_splits); } - CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits"; + TVM_FFI_CHECK_GE(n_splits, 2, ValueError) << "Cannot tile a loop into " << n_splits << " splits"; auto judge = [&](const std::vector& tile) { int64_t prod = 1; for (int i = partition_pos; i < n_splits; ++i) { @@ -383,7 +383,7 @@ std::vector SamplePartitionedTile( // Case 2. Use previous decision result = support::AsVector(decision->value()); int n = result.size(); - ICHECK_GE(n, 2); + TVM_FFI_ICHECK_GE(n, 2); int innerpart_prod = 1; for (int i = partition_pos; i < n; ++i) { innerpart_prod *= result[i]; @@ -423,7 +423,7 @@ tir::StmtSRef SampleComputeLocation(s_tir::ScheduleState self, const StmtSRef& block_sref, ffi::Optional* decision) { // Step 1. Collect all possible compute-at locations. auto [location_srefs, location_indices] = CollectComputeLocation(self, block_sref); - ICHECK_EQ(location_srefs.size(), location_indices.size()); + TVM_FFI_ICHECK_EQ(location_srefs.size(), location_indices.size()); // Step 2. If there was a previous decision, keep the decision unchanged if it exists in the // location candidates. Otherwise, pick the location before the previous decision. diff --git a/src/s_tir/schedule/schedule.cc b/src/s_tir/schedule/schedule.cc index 39a01ff3e7af..4ca8d4d15ed3 100644 --- a/src/s_tir/schedule/schedule.cc +++ b/src/s_tir/schedule/schedule.cc @@ -42,7 +42,7 @@ StmtSRef ScheduleNode::GetSRef(const StmtNode* stmt) const { ScheduleState state = this->state(); auto it = state->stmt2ref.find(stmt); if (it == state->stmt2ref.end()) { - LOG(FATAL) << "IndexError: The stmt doesn't exist in the IR"; + TVM_FFI_THROW(IndexError) << "The stmt doesn't exist in the IR"; } return it->second; } @@ -101,8 +101,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (auto expr_rv = obj.as()) { return self->Get(expr_rv.value()); } - LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " - << obj->GetTypeKey() << ". Its value is: " << obj; + TVM_FFI_THROW(TypeError) + << "Cannot evaluate the random variable of type: " << obj->GetTypeKey() + << ". Its value is: " << obj; throw; }) .def("s_tir.schedule.ScheduleGetSRef", @@ -116,7 +117,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (auto stmt = obj.as()) { return self->GetSRef(stmt.value()); } - LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << obj->GetTypeKey(); throw; }) .def("s_tir.schedule.ScheduleRemoveRV", [](Schedule self, ObjectRef obj) -> void { @@ -129,7 +130,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (auto expr_rv = obj.as()) { return self->RemoveRV(expr_rv.value()); } - LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Invalid type: " << obj->GetTypeKey(); throw; }); } @@ -159,8 +160,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (auto loop_rv = rv.as()) { return self->GetChildBlocks(loop_rv.value()); } - LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " - << rv->GetTypeKey() << ". Its value is: " << rv; + TVM_FFI_THROW(TypeError) + << "Cannot evaluate the random variable of type: " << rv->GetTypeKey() + << ". Its value is: " << rv; throw; }) .def_method("s_tir.schedule.ScheduleGetProducers", &ScheduleNode::GetProducers) @@ -183,8 +185,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (auto block_rv = rv.as()) { return self->AddUnitLoop(block_rv.value()); } else { - LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " - << rv->GetTypeKey() << ". Its value is: " << rv; + TVM_FFI_THROW(TypeError) + << "Cannot evaluate the random variable of type: " << rv->GetTypeKey() + << ". Its value is: " << rv; throw; } }); @@ -259,7 +262,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (auto blocks = target.as>()) { return self->Blockize(blocks.value(), preserve_unit_iters); } - LOG(FATAL) << "Unsupported target type: " << target->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unsupported target type: " << target->GetTypeKey(); }) .def("s_tir.schedule.ScheduleTensorize", [](Schedule self, ObjectRef rv, ffi::String intrin, bool preserve_unit_iters) { @@ -268,8 +271,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (auto loop_rv = rv.as()) { self->Tensorize(loop_rv.value(), intrin, preserve_unit_iters); } else { - LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " - << rv->GetTypeKey() << ". Its value is: " << rv; + TVM_FFI_THROW(TypeError) + << "Cannot evaluate the random variable of type: " << rv->GetTypeKey() + << ". Its value is: " << rv; } }); } @@ -286,22 +290,24 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (auto loop_rv = rv.as()) { return self->Annotate(loop_rv.value(), ann_key, ann_val); } - LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " - << rv->GetTypeKey() << ". Its value is: " << rv; + TVM_FFI_THROW(TypeError) + << "Cannot evaluate the random variable of type: " << rv->GetTypeKey() + << ". Its value is: " << rv; throw; }) - .def("s_tir.schedule.ScheduleUnannotate", [](Schedule self, ObjectRef rv, - const ffi::String& ann_key) { - if (auto block_rv = rv.as()) { - return self->Unannotate(block_rv.value(), ann_key); - } - if (auto loop_rv = rv.as()) { - return self->Unannotate(loop_rv.value(), ann_key); - } - LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() - << ". Its value is: " << rv; - throw; - }); + .def("s_tir.schedule.ScheduleUnannotate", + [](Schedule self, ObjectRef rv, const ffi::String& ann_key) { + if (auto block_rv = rv.as()) { + return self->Unannotate(block_rv.value(), ann_key); + } + if (auto loop_rv = rv.as()) { + return self->Unannotate(loop_rv.value(), ann_key); + } + TVM_FFI_THROW(TypeError) + << "Cannot evaluate the random variable of type: " << rv->GetTypeKey() + << ". Its value is: " << rv; + throw; + }); } /******** (FFI) Layout transformation ********/ diff --git a/src/s_tir/schedule/state.cc b/src/s_tir/schedule/state.cc index e9fdeec1c445..20c319d88453 100644 --- a/src/s_tir/schedule/state.cc +++ b/src/s_tir/schedule/state.cc @@ -95,8 +95,8 @@ bool ProducerCoversConsumer(const ffi::Array& buffer_shape, const ffi::Array& produced_region, const ffi::Array& consumed_region, arith::Analyzer* analyzer) { - ICHECK_EQ(buffer_shape.size(), consumed_region.size()); - ICHECK_EQ(produced_region.size(), consumed_region.size()); + TVM_FFI_ICHECK_EQ(buffer_shape.size(), consumed_region.size()); + TVM_FFI_ICHECK_EQ(produced_region.size(), consumed_region.size()); int ndim = produced_region.size(); for (int i = 0; i < ndim; ++i) { arith::IntSet buffer_size = arith::IntSet::FromMinExtent(0, buffer_shape[i]); @@ -138,9 +138,9 @@ bool ProducerCoversConsumer(const ffi::Array& buffer_shape, * \param new_stmt The statement that replaces the statement inside the sref */ void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) { - ICHECK(new_stmt->IsInstance() || new_stmt->IsInstance()); + TVM_FFI_ICHECK(new_stmt->IsInstance() || new_stmt->IsInstance()); const StmtNode* old_stmt = sref->stmt; - ICHECK_NE(new_stmt, old_stmt); + TVM_FFI_ICHECK_NE(new_stmt, old_stmt); self->stmt2ref[new_stmt] = ffi::GetRef(sref); self->stmt2ref.erase(sref->stmt); sref->stmt = new_stmt; @@ -235,7 +235,7 @@ class SBlockInfoCollector : private StmtVisitor { // Step 2.1. Extract the path to the scope root std::unordered_map> lca_loc; for (const StmtSRefNode* p = consumer_block_sref.get(); p != limit; p = p->parent) { - ICHECK(p != nullptr); + TVM_FFI_ICHECK(p != nullptr); lca_loc[p] = {}; } // Step 2.2. For each producer, find the LCA of the consumer @@ -249,7 +249,7 @@ class SBlockInfoCollector : private StmtVisitor { } const StmtSRef& producer = dep->src; for (const StmtSRefNode* p = producer.get();; p = p->parent) { - ICHECK(p != nullptr); + TVM_FFI_ICHECK(p != nullptr); auto it = lca_loc.find(p); // Find the first (lowest) position in the ancestor of the consumer, // which is the LCA by definition @@ -373,7 +373,8 @@ class SBlockInfoCollector : private StmtVisitor { /**************** Constructor ****************/ ScheduleState::ScheduleState(IRModule mod, int debug_mask, bool enable_check) { - CHECK_GE(debug_mask, -1) << "ValueError: negative `debug_mask` other than -1 is not supported"; + TVM_FFI_CHECK_GE(debug_mask, -1, ValueError) + << "negative `debug_mask` other than -1 is not supported"; ObjectPtr n = ffi::make_object(); ScheduleStateNode* self = n.get(); // Set `n->mod` @@ -545,8 +546,8 @@ class SRefTreePruner : public StmtVisitor { return; } auto it = self_->stmt2ref.find(op); - ICHECK(it != self_->stmt2ref.end()) - << "IndexError: Cannot find corresponding StmtSRef for the loop:\n" + TVM_FFI_CHECK(it != self_->stmt2ref.end(), IndexError) + << "Cannot find corresponding StmtSRef for the loop:\n" << ffi::GetRef(op); StmtSRef& sref = it->second; // Detect reuse @@ -568,8 +569,8 @@ class SRefTreePruner : public StmtVisitor { return; } auto it = self_->stmt2ref.find(op); - ICHECK(it != self_->stmt2ref.end()) - << "IndexError: Cannot find corresponding StmtSRef for the block:\n" + TVM_FFI_CHECK(it != self_->stmt2ref.end(), IndexError) + << "Cannot find corresponding StmtSRef for the block:\n" << ffi::GetRef(op); StmtSRef& sref = it->second; // Detect reuse @@ -724,11 +725,11 @@ class ChildReplacer : private StmtMutator { static Stmt Replace(const StmtNode* parent_stmt, const StmtNode* child_src_stmt, const Stmt& child_tgt_stmt, int seq_index, bool allow_copy_on_write) { // Check the invariant - ICHECK(child_src_stmt->IsInstance() || // - child_src_stmt->IsInstance()); - ICHECK(child_tgt_stmt->IsInstance() || // - child_tgt_stmt->IsInstance() || // - child_tgt_stmt->IsInstance()); + TVM_FFI_ICHECK(child_src_stmt->IsInstance() || // + child_src_stmt->IsInstance()); + TVM_FFI_ICHECK(child_tgt_stmt->IsInstance() || // + child_tgt_stmt->IsInstance() || // + child_tgt_stmt->IsInstance()); ChildReplacer replacer(child_src_stmt, child_tgt_stmt, seq_index); replacer.allow_copy_on_write_ = allow_copy_on_write; return replacer.CopyOnWriteAndVisit(parent_stmt); @@ -801,7 +802,7 @@ class ChildReplacer : private StmtMutator { new_loop->body = this->VisitStmt(new_loop->body); return For(std::move(new_loop)); } - LOG(FATAL) << "TypeError: Unexpected type: " << parent_stmt->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Unexpected type: " << parent_stmt->GetTypeKey(); throw; } @@ -825,10 +826,11 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ (src_stmt->IsInstance() && tgt_stmt->IsInstance()) || (src_stmt->IsInstance() && tgt_stmt->IsInstance()); if (!input_correct) { - LOG(FATAL) << "TypeError: src_stmt has type: " << src_stmt->GetTypeKey() - << ". tgt_stmt has type: " << tgt_stmt->GetTypeKey() << ".\nsrc_stmt:\n" - << ffi::GetRef(src_stmt) << "\ntgt_stmt:\n" - << tgt_stmt; + TVM_FFI_THROW(TypeError) << "src_stmt has type: " << src_stmt->GetTypeKey() + << ". tgt_stmt has type: " << tgt_stmt->GetTypeKey() + << ".\nsrc_stmt:\n" + << ffi::GetRef(src_stmt) << "\ntgt_stmt:\n" + << tgt_stmt; } } // Rule out the case that no replacement happens @@ -954,7 +956,7 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ ffi::MapObj* new_map = new_mod->functions.CopyOnWrite(); // Move out the PrimFunc where the sref belong while ensuring uniqueness PrimFunc ref_new_func = Downcast(std::move(new_map->at(g_var))); - ICHECK(ref_new_func.get() == g_func); + TVM_FFI_ICHECK(ref_new_func.get() == g_func); PrimFuncNode* new_func = ref_new_func.CopyOnWrite(); // If `g_func` was not unique, after the 3 lines above: // `ref_new_func` points to a unique PrimFunc @@ -981,7 +983,7 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_ } void ScheduleStateNode::DebugVerify() const { - ICHECK_GE(debug_mask, -1); + TVM_FFI_ICHECK_GE(debug_mask, -1); uint32_t flag = (debug_mask != -1) // ? static_cast(debug_mask) // : std::numeric_limits::max(); @@ -998,8 +1000,8 @@ void ScheduleStateNode::DebugVerify() const { SBlockInfo ScheduleStateNode::GetSBlockInfo(const StmtSRef& block_sref) const { TVM_SREF_TO_SBLOCK(block_sref); auto it = this->block_info.find(block_sref); - CHECK(it != this->block_info.end()) - << "IndexError: Cannot find the corresponding SBlockScope to the block sref:\n" + TVM_FFI_CHECK(it != this->block_info.end(), IndexError) + << "Cannot find the corresponding SBlockScope to the block sref:\n" << ffi::GetRef(block_sref->stmt); return it->second; } diff --git a/src/s_tir/schedule/trace.cc b/src/s_tir/schedule/trace.cc index cf1e01b0f11a..24e4636daed1 100644 --- a/src/s_tir/schedule/trace.cc +++ b/src/s_tir/schedule/trace.cc @@ -66,8 +66,8 @@ ffi::Array TranslateInputRVs(const ffi::Array& inputs, return std::nullopt; } const Object* dst = it->second; - ICHECK(dst->IsInstance()) - << "TypeError: Expect 'tir.Var', but gets: " << dst->GetTypeKey(); + TVM_FFI_CHECK(dst->IsInstance(), TypeError) + << "Expect 'tir.Var', but gets: " << dst->GetTypeKey(); return ffi::GetRef(static_cast(dst)); }; @@ -81,7 +81,7 @@ ffi::Array TranslateInputRVs(const ffi::Array& inputs, input.as() || // RV: loop input.as()) { // RV: var auto it = rv_map.find(input.as()); - ICHECK(it != rv_map.end()) << "IndexError: Random variable doesn't exist: " << input; + TVM_FFI_CHECK(it != rv_map.end(), IndexError) << "Random variable doesn't exist: " << input; result.push_back(ffi::GetRef(it->second)); } else if (auto expr = input.try_cast()) { // RV: Expr result.push_back(Substitute(expr.value(), f_subst_with_rv_map)); @@ -91,8 +91,8 @@ ffi::Array TranslateInputRVs(const ffi::Array& inputs, // Recursively convert elements of the array into a new list of ObjectRefs. result.push_back(TranslateInputRVs(arr.value(), rv_map)); } else { - ICHECK(false) << "TypeError: Cannot recognize the type of an input random variable: " - << input.GetTypeKey(); + TVM_FFI_CHECK(false, TypeError) + << "Cannot recognize the type of an input random variable: " << input.GetTypeKey(); throw; } } @@ -125,7 +125,7 @@ ffi::Array TranslateInputRVs( // Case 1. SBlockRV, LoopRV, VarRV results.push_back(it->second); } else { - LOG(FATAL) << "IndexError: Random variable is not defined " << input; + TVM_FFI_THROW(IndexError) << "Random variable is not defined " << input; throw; } } else if (input.as() || input.as()) { @@ -149,7 +149,7 @@ ffi::Array TranslateInputRVs( }); results.push_back(index_map); } else { - LOG(FATAL) << "TypeError: Stringifying is not supported for type: " << input.GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Stringifying is not supported for type: " << input.GetTypeKey(); throw; } } @@ -182,8 +182,9 @@ ffi::Array TranslateInputRVs(const ffi::Array& inputs, continue; } auto opt_str = input.as(); - CHECK(opt_str) << "TypeError: Expect String, but gets: " << input.GetTypeKey(); - CHECK_GT((*opt_str).size(), 0) << "ValueError: Empty string is not allowed in input names"; + TVM_FFI_CHECK(opt_str, TypeError) << "Expect String, but gets: " << input.GetTypeKey(); + TVM_FFI_CHECK_GT((*opt_str).size(), 0, ValueError) + << "Empty string is not allowed in input names"; const char* name = (*opt_str).data(); int64_t size = (*opt_str).size(); if (name[0] == '{' && name[size - 1] == '}') { @@ -201,7 +202,7 @@ ffi::Array TranslateInputRVs(const ffi::Array& inputs, results.push_back(index_map); continue; } else { - LOG(FATAL) << "TypeError: Unexpected object: " << obj.GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Unexpected object: " << obj.GetTypeKey(); throw; } } @@ -212,7 +213,8 @@ ffi::Array TranslateInputRVs(const ffi::Array& inputs, } // Case 0 & 1. None, SBlockRV, LoopRV, VarRV auto it = named_rvs.find(name); - CHECK(it != named_rvs.end()) << "ValueError: The random variable is not defined: " << name; + TVM_FFI_CHECK(it != named_rvs.end(), ValueError) + << "The random variable is not defined: " << name; results.push_back(it->second); } return results; @@ -222,12 +224,12 @@ ffi::Array TranslateInputRVs(const ffi::Array& inputs, void TranslateAddOutputRVs(const ffi::Array& old_outputs, const ffi::Array& new_outputs, std::unordered_map* rv_map) { - ICHECK_EQ(old_outputs.size(), new_outputs.size()); + TVM_FFI_ICHECK_EQ(old_outputs.size(), new_outputs.size()); int n = old_outputs.size(); for (int i = 0; i < n; ++i) { const Object* old_rv = old_outputs[i].as(); const Object* new_rv = new_outputs[i].as(); - ICHECK(old_rv != nullptr && new_rv != nullptr); + TVM_FFI_ICHECK(old_rv != nullptr && new_rv != nullptr); (*rv_map)[old_rv] = new_rv; } } @@ -239,9 +241,8 @@ ffi::Array TranslateAddOutputRVs( results.reserve(outputs.size()); for (const Any& output : outputs) { int i = rv_names->size(); - ICHECK(!rv_names->count(output.cast())) - << "ValueError: The random variable has been produced once: " - << rv_names->at(output.cast()); + TVM_FFI_CHECK(!rv_names->count(output.cast()), ValueError) + << "The random variable has been produced once: " << rv_names->at(output.cast()); ffi::String result; if (output == nullptr) { result = "_"; @@ -252,8 +253,8 @@ ffi::Array TranslateAddOutputRVs( } else if (output.as()) { result = "v" + std::to_string(i); } else { - LOG(FATAL) << "TypeError: Cannot recognize the type of the random variable: " - << output.GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Cannot recognize the type of the random variable: " + << output.GetTypeKey(); throw; } results.push_back(result); @@ -265,7 +266,7 @@ ffi::Array TranslateAddOutputRVs( void TranslateAddOutputRVs(const ffi::Array& old_outputs, const ffi::Array& new_outputs, std::unordered_map* named_rvs) { - ICHECK_EQ(old_outputs.size(), new_outputs.size()); + TVM_FFI_ICHECK_EQ(old_outputs.size(), new_outputs.size()); int n = old_outputs.size(); for (int i = 0; i < n; ++i) { named_rvs->emplace(Downcast(old_outputs[i]), new_outputs[i].cast()); @@ -387,16 +388,16 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { // Parse `json` into `json_insts` and `json_decisions` try { const ffi::ArrayObj* arr = json.as(); - ICHECK(arr && arr->size() == 2); + TVM_FFI_ICHECK(arr && arr->size() == 2); const auto* arr0 = arr->at(0).as(); const auto* arr1 = arr->at(1).as(); - ICHECK(arr0 && arr1); + TVM_FFI_ICHECK(arr0 && arr1); json_insts = ffi::GetRef>(arr0); json_decisions = ffi::GetRef>(arr1); } catch (const tvm::Error& e) { - LOG(FATAL) << "ValueError: The json entry of a trace should contain two arrays, an array of " - "instructions and an array of decisions, but gets: " - << json; + TVM_FFI_THROW(ValueError) << "The json entry of a trace should contain two arrays, an array of " + "instructions and an array of decisions, but gets: " + << json; throw; } // Parse `json_decisions` @@ -406,15 +407,15 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { Any decision{nullptr}; try { const ffi::ArrayObj* arr = decision_entry.as(); - ICHECK(arr && arr->size() == 2); + TVM_FFI_ICHECK(arr && arr->size() == 2); auto arr0 = arr->at(0).try_cast(); - ICHECK(arr0); + TVM_FFI_ICHECK(arr0); index = arr0.value()->value; decision = arr->at(1); } catch (const tvm::Error& e) { - LOG(FATAL) << "ValueError: Each entry of a json decision should be a tuple [index, " - "decision], but gets: " - << decision_entry; + TVM_FFI_THROW(ValueError) << "Each entry of a json decision should be a tuple [index, " + "decision], but gets: " + << decision_entry; throw; } decisions[index] = std::move(decision); @@ -430,16 +431,16 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { // Parse the entry try { const auto* arr = inst_entry.as(); - ICHECK(arr && arr->size() == 4); + TVM_FFI_ICHECK(arr && arr->size() == 4); ffi::String arr0 = arr->at(0).cast(); kind = InstructionKind::Get(arr0); inputs = arr->at(1).cast>(); attrs = arr->at(2).cast>(); outputs = arr->at(3).cast>(); } catch (const tvm::Error& e) { - LOG(FATAL) << "ValueError: Each entry of a json instruction should be a tuple [inst_name, " - "inputs, attrs, outputs], but gets: " - << inst_entry << "\nThe error is: " << e.what(); + TVM_FFI_THROW(ValueError) << "Each entry of a json instruction should be a tuple [inst_name, " + "inputs, attrs, outputs], but gets: " + << inst_entry << "\nThe error is: " << e.what(); throw; } // Parse inputs @@ -524,7 +525,7 @@ Trace TraceNode::Simplified(bool remove_postproc) const { TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { const auto* self = obj.as(); - ICHECK_NOTNULL(self); + TVM_FFI_ICHECK_NOTNULL(self); p->stream << "# from tvm import s_tir\n"; p->stream << "def apply_trace(sch: s_tir.Schedule) -> None:\n"; ffi::Array repr = self->AsPython(/*remove_postproc=*/false); diff --git a/src/s_tir/schedule/traced_schedule.cc b/src/s_tir/schedule/traced_schedule.cc index f6e91ebf85b5..68541fc26ddc 100644 --- a/src/s_tir/schedule/traced_schedule.cc +++ b/src/s_tir/schedule/traced_schedule.cc @@ -125,9 +125,10 @@ SBlockRV TracedScheduleNode::GetSBlock(const ffi::String& name, } else if (func_working_on_.defined()) { gv = this->func_working_on_.value(); } else { - LOG(FATAL) << "ValueError: `get_sblock` does not know which function to be working on. Please " - "specify the function name explicitly, or call `work_on` to specify the function " - "before using `get_sblock`."; + TVM_FFI_THROW(ValueError) + << "`get_sblock` does not know which function to be working on. Please " + "specify the function name explicitly, or call `work_on` to specify the function " + "before using `get_sblock`."; } SBlockRV result = ConcreteScheduleNode::GetSBlock(name, func_name); diff --git a/src/s_tir/schedule/transform.cc b/src/s_tir/schedule/transform.cc index 0e1c626b6ddf..dd89f7e5c691 100644 --- a/src/s_tir/schedule/transform.cc +++ b/src/s_tir/schedule/transform.cc @@ -314,7 +314,7 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ return; } } - ICHECK(sref != nullptr && sref->stmt != nullptr); + TVM_FFI_ICHECK(sref != nullptr && sref->stmt != nullptr); const auto* leaf_block = TVM_SREF_TO_SBLOCK(leaf_block_sref); const auto* scope_block = TVM_SREF_TO_SBLOCK(sref); throw OnlyLeafError(self->mod, ffi::GetRef(leaf_block), ffi::GetRef(scope_block)); @@ -355,9 +355,9 @@ ffi::Optional TileWithTensorIntrin(const s_tir::Schedule& sch, // The original producer is input. continue; } - ICHECK_EQ(the_original_producers.size(), 1u); + TVM_FFI_ICHECK_EQ(the_original_producers.size(), 1u); auto the_original_producer = the_original_producers[0]; - ICHECK(original_producers.count(sch->GetSRef(the_original_producer).get())); + TVM_FFI_ICHECK(original_producers.count(sch->GetSRef(the_original_producer).get())); inlined_producers.push_back(the_original_producer); } for (const auto& consumer : sch->GetConsumers(block_rv)) { @@ -371,9 +371,9 @@ ffi::Optional TileWithTensorIntrin(const s_tir::Schedule& sch, // The original consumer is output. continue; } - ICHECK_EQ(the_original_consumers.size(), 1u); + TVM_FFI_ICHECK_EQ(the_original_consumers.size(), 1u); auto the_original_consumer = the_original_consumers[0]; - ICHECK(original_consumers.count(sch->GetSRef(the_original_consumer).get())); + TVM_FFI_ICHECK(original_consumers.count(sch->GetSRef(the_original_consumer).get())); inlined_consumers.push_back(consumer); } @@ -405,22 +405,22 @@ ffi::Optional TileWithTensorIntrin(const s_tir::Schedule& sch, const tir::StmtSRef& block_loop_sref = kv.first; const tir::ForNode* block_loop = block_loop_sref->StmtAs(); const tir::ForNode* desc_loop = kv.second.get(); - ICHECK(block_loop != nullptr && desc_loop != nullptr); + TVM_FFI_ICHECK(block_loop != nullptr && desc_loop != nullptr); // Extract the loop extent PrimExpr block_extent = analyzer.Simplify(block_loop->extent); PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent); const auto* int_block_extent = block_extent.as(); const auto* int_desc_extent = desc_extent.as(); - ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr); + TVM_FFI_ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr); // Check divisibility int64_t total = int_block_extent->value; int64_t inner = int_desc_extent->value; - ICHECK_EQ(total % inner, 0); + TVM_FFI_ICHECK_EQ(total % inner, 0); // Do the split. Leave the outer extent as std::nullopt (unspecified) so that the split factors // can be used for different extents (needed during tuning). ffi::Array split = sch->Split(loop2rv.at(block_loop_sref), {std::nullopt, Integer(inner)}); - ICHECK_EQ(split.size(), 2); + TVM_FFI_ICHECK_EQ(split.size(), 2); inner_loops.insert(sch->GetSRef(split[1]).operator->()); // The inner split will be reordered to the loop domain that is tensorized int desc_loop_index = info->desc_loop_indexer.at(ffi::GetRef(desc_loop)).IntValue(); @@ -439,7 +439,7 @@ ffi::Optional TileWithTensorIntrin(const s_tir::Schedule& sch, } reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end()); sch->Reorder(reorder_list); - ICHECK(!reorder_suffix.empty()); + TVM_FFI_ICHECK(!reorder_suffix.empty()); return reorder_suffix[0]; } diff --git a/src/s_tir/schedule/utils.h b/src/s_tir/schedule/utils.h index c3abdae6c990..715e34b09f61 100644 --- a/src/s_tir/schedule/utils.h +++ b/src/s_tir/schedule/utils.h @@ -111,7 +111,7 @@ inline bool CanRelaxStorageUnderThread(const runtime::StorageScope& storage_scop * \return The removal result */ inline Stmt RemoveFromSeqStmt(const SeqStmt& seq, const Stmt& to_remove) { - ICHECK_GT(seq->size(), 1); + TVM_FFI_ICHECK_GT(seq->size(), 1); ffi::Array new_stmts; new_stmts.reserve(seq->size()); for (const Stmt& stmt : seq->seq) { @@ -279,7 +279,7 @@ inline ffi::Optional GetAnn(const StmtSRef& sref, const ffi::String& } else if (const auto* block = sref->StmtAs()) { return GetAnn(block, ann_key); } else { - LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Unknown type of sref: " << sref->stmt->GetTypeKey(); throw; } } @@ -347,7 +347,7 @@ inline void ReorderAndFuseReductionLoops(const s_tir::Schedule& sch, } } // Step 3. Apply reordering if new_order differs from the original order. - ICHECK_EQ(new_order.size(), loops.size()); + TVM_FFI_ICHECK_EQ(new_order.size(), loops.size()); for (size_t i = 0; i < loops.size(); ++i) { if (!new_order[i].same_as(loops[i])) { sch->Reorder(new_order); @@ -355,7 +355,8 @@ inline void ReorderAndFuseReductionLoops(const s_tir::Schedule& sch, } } // Step 4. Fuse all the reduction loops if there are multiple reduction loops. - CHECK(!reduction_loops.empty()) << "ValueError: There should be at least one reduction loop"; + TVM_FFI_CHECK(!reduction_loops.empty(), ValueError) + << "There should be at least one reduction loop"; if (reduction_loops.size() > 1) { *fused_reduce_loop = sch->Fuse(reduction_loops); } else { @@ -374,7 +375,7 @@ inline ffi::String BufferIndexType2Str(BufferIndexType buffer_index_type) { if (buffer_index_type == BufferIndexType::kRead) { return "read"; } else { - ICHECK(buffer_index_type == BufferIndexType::kWrite); + TVM_FFI_ICHECK(buffer_index_type == BufferIndexType::kWrite); return "write"; } } diff --git a/src/s_tir/transform/annotate_irregular_loop.cc b/src/s_tir/transform/annotate_irregular_loop.cc index 76c41a25b612..b21245ea15c9 100644 --- a/src/s_tir/transform/annotate_irregular_loop.cc +++ b/src/s_tir/transform/annotate_irregular_loop.cc @@ -41,12 +41,12 @@ class IrregularLoopAnnotator : public StmtMutator { has_jump_ = false; For res = Downcast(StmtMutator::VisitStmt_(op)); if (has_jump_) { - CHECK(op->kind == ForKind::kSerial) + TVM_FFI_ICHECK(op->kind == ForKind::kSerial) << "Loop kind " << op->kind << " is invalid for irregular loop " << op->loop_var; for (const char* key : {tir::attr::pragma_auto_unroll_max_step, tir::attr::pragma_unroll_explicit, tir::attr::pragma_loop_partition_hint, tir::attr::software_pipeline_stage}) { - CHECK(!res->annotations.count(key)) + TVM_FFI_ICHECK(!res->annotations.count(key)) << "Annotation `" << key << "` is invalid for irregular loop " << op->loop_var; } res.CopyOnWrite()->annotations.Set(tir::attr::irregular_loop_mark, 1); diff --git a/src/s_tir/transform/bound_checker.cc b/src/s_tir/transform/bound_checker.cc index 2f2061d9c182..dbee2effdb7e 100644 --- a/src/s_tir/transform/bound_checker.cc +++ b/src/s_tir/transform/bound_checker.cc @@ -194,7 +194,7 @@ class BoundChecker : public StmtExprMutator { ffi::Array indices = pair.first; ffi::Array shape = pair.second; - ICHECK_EQ(indices.size(), shape.size()) + TVM_FFI_ICHECK_EQ(indices.size(), shape.size()) << "Mismatch between dimension of physical shape and physical indices"; for (size_t i = 0; i < indices.size(); i++) { diff --git a/src/s_tir/transform/canonicalize_loop.cc b/src/s_tir/transform/canonicalize_loop.cc index 5d737d7640d1..99ee0f614b14 100644 --- a/src/s_tir/transform/canonicalize_loop.cc +++ b/src/s_tir/transform/canonicalize_loop.cc @@ -51,7 +51,8 @@ class LoopCanonicalizer : public StmtExprMutator { // report warning for negative step, since it would be a forever loop if (!analyzer_.CanProveGreaterEqual(step, 1)) { // TODO(tvm): prove dynamic shaped step - LOG(FATAL) << "Loop step for " << op->loop_var << " may not be positive: " << step; + TVM_FFI_THROW(InternalError) + << "Loop step for " << op->loop_var << " may not be positive: " << step; } new_iter_info_[loop_var] = std::make_pair(step, op->min); diff --git a/src/s_tir/transform/compact_buffer_region.cc b/src/s_tir/transform/compact_buffer_region.cc index dba8e25ebec7..c9a4a583bcad 100644 --- a/src/s_tir/transform/compact_buffer_region.cc +++ b/src/s_tir/transform/compact_buffer_region.cc @@ -227,8 +227,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor { void VisitStmt_(const SBlockNode* op) final { // Step 0. Check there is no init part and block is opaque - ICHECK(!op->init.defined()); - ICHECK_EQ(op->iter_vars.size(), 0) << "CompactBufferRegion only works on opaque blocks"; + TVM_FFI_ICHECK(!op->init.defined()); + TVM_FFI_ICHECK_EQ(op->iter_vars.size(), 0) << "CompactBufferRegion only works on opaque blocks"; // Step 1. Record and update current read/write region annotations std::unordered_map, ObjectPtrHash, ObjectPtrEqual> cur_access_annotations; @@ -286,7 +286,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { explicit_access_annotations_.clear(); // Step 8. Update buffer_access_region_ from relaxed_accesses_ for inner buffers. for (const Buffer& buffer : op->alloc_buffers) { - ICHECK_EQ(var2buffer_[buffer->data].size(), 1) + TVM_FFI_ICHECK_EQ(var2buffer_[buffer->data].size(), 1) << "Block allocation buffer shoud not be alised"; SimplifyAndNarrowBufferRegionFromNDIntSet(buffer); } @@ -341,7 +341,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { /*! \brief Record information on the buffer defining point. */ void VisitBufferDef(const Var& buffer_data) { auto it = buffer_scope_depth_.find(buffer_data); - ICHECK(it == buffer_scope_depth_.end()) << buffer_data << " has duplicate definitions"; + TVM_FFI_ICHECK(it == buffer_scope_depth_.end()) << buffer_data << " has duplicate definitions"; buffer_scope_depth_.insert(it, {buffer_data, ancestor_iters_.size()}); } @@ -360,7 +360,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { continue; } auto dom_it = dom_map_.find(v); - ICHECK(dom_it != dom_map_.end()) + TVM_FFI_ICHECK(dom_it != dom_map_.end()) << "Could not find domain for loop variable " << v->name_hint; non_relaxed[i] = dom_it->second; dom_map_.erase(dom_it); @@ -428,7 +428,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { */ void SimplifyAndNarrowBufferRegionFromNDIntSet(const Buffer& buffer) { auto it = relaxed_accesses_.find(buffer); - ICHECK(it != relaxed_accesses_.end()) + TVM_FFI_ICHECK(it != relaxed_accesses_.end()) << buffer << " is allocated but not accessed within block scope"; const ffi::Array& original_shape = buffer->shape; @@ -565,7 +565,7 @@ class BufferCompactor : public StmtExprMutator { Stmt VisitStmt_(const SBlockNode* op) final { // Step 0. Check there is no Init part. - ICHECK(!op->init.defined()); + TVM_FFI_ICHECK(!op->init.defined()); // Step 1. Reallocate and rewrite alloc_buffers, also update BufferAllocInfo. ffi::Array alloc_buffers = op->alloc_buffers.Map([this](const Buffer& buf) { return RewriteAllocBuffer(buf); }); @@ -603,7 +603,7 @@ class BufferCompactor : public StmtExprMutator { } ffi::Array new_shape = GetBufferAllocationShape(new_buffer); auto n = allocate.CopyOnWrite(); - ICHECK(n->buffer_var.same_as(new_buffer->data)); + TVM_FFI_ICHECK(n->buffer_var.same_as(new_buffer->data)); n->extents = new_shape; return allocate; } @@ -622,7 +622,7 @@ class BufferCompactor : public StmtExprMutator { return; } const BufferAllocInfo& info = it->second; - ICHECK_EQ(indices->size(), info.region.size()); + TVM_FFI_ICHECK_EQ(indices->size(), info.region.size()); int ndim = info.region.size(); ffi::Array new_indices; new_indices.reserve(ndim); @@ -640,7 +640,7 @@ class BufferCompactor : public StmtExprMutator { return; } const BufferAllocInfo& info = it->second; - ICHECK_EQ(region->size(), info.region.size()); + TVM_FFI_ICHECK_EQ(region->size(), info.region.size()); Region new_region; new_region.reserve(info.region.size()); for (size_t i = 0; i < info.region.size(); ++i) { @@ -683,7 +683,7 @@ ffi::Array CalcStrides(const BufferAllocInfo& alloc_info, const ffi::Array& shape) { std::vector strides; if (alloc_info.dim_aligns.size()) { - ICHECK(alloc_info.dim_aligns.size() == shape.size()); + TVM_FFI_ICHECK(alloc_info.dim_aligns.size() == shape.size()); strides.resize(shape.size()); PrimExpr stride = make_const(shape[0].dtype(), 1); for (size_t i = shape.size(); i != 0; --i) { diff --git a/src/s_tir/transform/convert_blocks_to_opaque.cc b/src/s_tir/transform/convert_blocks_to_opaque.cc index 32971922f973..53b47735b89d 100644 --- a/src/s_tir/transform/convert_blocks_to_opaque.cc +++ b/src/s_tir/transform/convert_blocks_to_opaque.cc @@ -47,7 +47,7 @@ class OpaqueBlockConverter : public StmtExprMutator { OpaqueBlockConverter() = default; PrimExpr VisitExpr_(const VarNode* var) final { - CHECK(!forbidden_iter_vars_.count(var)) + TVM_FFI_ICHECK(!forbidden_iter_vars_.count(var)) << "Variable " << var->name_hint << " occurs in the predicate or iter_values of a block, " << "but isn't defined until the body of the block"; @@ -59,7 +59,7 @@ class OpaqueBlockConverter : public StmtExprMutator { } Stmt VisitStmt_(const SBlockNode* block) final { - ICHECK(!block->init.defined()) + TVM_FFI_ICHECK(!block->init.defined()) << "Block Init part is not allowed in pass ConvertBlocksToOpaque"; SBlock new_block = Downcast(StmtExprMutator::VisitStmt_(block)); if (!new_block->iter_vars.empty()) { @@ -70,7 +70,7 @@ class OpaqueBlockConverter : public StmtExprMutator { Stmt VisitStmt_(const SBlockRealizeNode* realize) final { const auto* block_op = realize->block.get(); - ICHECK(!block_op->init.defined()); + TVM_FFI_ICHECK(!block_op->init.defined()); // Step 1. Visit the predicate and iter_values, without any variable bindings for (const auto& iter : block_op->iter_vars) forbidden_iter_vars_.insert(iter->var.get()); @@ -80,7 +80,7 @@ class OpaqueBlockConverter : public StmtExprMutator { for (const auto& iter : block_op->iter_vars) forbidden_iter_vars_.erase(iter->var.get()); // Step 2. Update "block vars => binding values" for substitution. - ICHECK_EQ(block_op->iter_vars.size(), iter_values.size()); + TVM_FFI_ICHECK_EQ(block_op->iter_vars.size(), iter_values.size()); for (int i = 0, n = block_op->iter_vars.size(); i < n; ++i) { IterVar block_var = block_op->iter_vars[i]; PrimExpr v = this->VisitExpr(iter_values[i]); diff --git a/src/s_tir/transform/default_gpu_schedule.cc b/src/s_tir/transform/default_gpu_schedule.cc index fee5cd8361df..216182e0f434 100644 --- a/src/s_tir/transform/default_gpu_schedule.cc +++ b/src/s_tir/transform/default_gpu_schedule.cc @@ -46,7 +46,7 @@ void ThreadBind(s_tir::Schedule sch, const s_tir::SBlockRV& block, int64_t max_t // when there is no loops, tir will add a dummy iter var for the block // so loops.size() == 0 && iters.size() == 1 - ICHECK(loops.size() == iters.size() || (loops.size() == 0 && iters.size() == 1)); + TVM_FFI_ICHECK(loops.size() == iters.size() || (loops.size() == 0 && iters.size() == 1)); ffi::Array data_parallel_loops; // only fuse data parallel loops @@ -137,12 +137,13 @@ Pass DefaultGPUSchedule() { if (func_target.defined()) { target = func_target.value(); } - ICHECK(target.defined()) << "The target is missing either in the current context or in " - "the prim_func's attribute."; + TVM_FFI_ICHECK(target.defined()) + << "The target is missing either in the current context or in " + "the prim_func's attribute."; // get the max thread per block from target. ffi::Optional opt_max_thread_per_block = target->GetAttr("max_num_threads"); - ICHECK(opt_max_thread_per_block.defined()) + TVM_FFI_ICHECK(opt_max_thread_per_block.defined()) << "max_num_threads is not set for target " << target; int64_t max_thread_per_block = opt_max_thread_per_block.value().IntValue(); diff --git a/src/s_tir/transform/hoist_expression.cc b/src/s_tir/transform/hoist_expression.cc index add5b663bcde..7858ff0e14dd 100644 --- a/src/s_tir/transform/hoist_expression.cc +++ b/src/s_tir/transform/hoist_expression.cc @@ -283,9 +283,9 @@ class HoistInfoCollector : public StmtExprVisitor { } void VisitBinding(Var var, PrimExpr value, HoistedLetBindings hoist_from) { - ICHECK_EQ(let_var_to_loop_vars.count(var.get()), 0) + TVM_FFI_ICHECK_EQ(let_var_to_loop_vars.count(var.get()), 0) << "Multiple nested definitions of variable " << var; - ICHECK_EQ(let_var_to_let_vars.count(var.get()), 0) + TVM_FFI_ICHECK_EQ(let_var_to_let_vars.count(var.get()), 0) << "Multiple nested definitions of variable " << var; if (auto info = FindHoistDestination(value)) { @@ -495,7 +495,7 @@ class ExpressionHoister : public arith::IRMutatorWithAnalyzer { Stmt stmt = Parent::VisitStmt_(op); auto it = loop_info_lookup.find(op); - ICHECK(it != loop_info_lookup.end()) + TVM_FFI_ICHECK(it != loop_info_lookup.end()) << "Could not find pre-pass information for loop over " << op->loop_var; return WrapHoistedStatements(stmt, it->second); } diff --git a/src/s_tir/transform/inject_double_buffer.cc b/src/s_tir/transform/inject_double_buffer.cc index 465ccaf173ad..99da2ac51b97 100644 --- a/src/s_tir/transform/inject_double_buffer.cc +++ b/src/s_tir/transform/inject_double_buffer.cc @@ -117,14 +117,14 @@ class DoubleBufferInjector : public StmtExprMutator { StorageEntry& entry = it->second; entry.scope = GetPtrStorageScope(op->buffer_var); - ICHECK_EQ(op->extents.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " - << "Has FlattenBuffer been run?"; + TVM_FFI_ICHECK_EQ(op->extents.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " + << "Has FlattenBuffer been run?"; entry.stride = op->extents[0]; Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); ffi::Array new_extents = {op->extents[0] * make_const(op->extents[0].dtype(), 2)}; - ICHECK(entry.loop != nullptr); + TVM_FFI_ICHECK(entry.loop != nullptr); auto& alloc_nest = loop_allocs_[entry.loop]; alloc_nest.emplace_back(Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0), op->annotations)); @@ -149,9 +149,9 @@ class DoubleBufferInjector : public StmtExprMutator { const ForNode* old_loop = stmt.as(); if (split_loop_ != 0) { // Explicitly unroll the loop - ICHECK(split_loop_ % 2 == 0 || split_loop_ == 1) + TVM_FFI_ICHECK(split_loop_ % 2 == 0 || split_loop_ == 1) << "It is better to split with multiple of 2"; - ICHECK(is_zero(old_loop->min)); + TVM_FFI_ICHECK(is_zero(old_loop->min)); PrimExpr zero = old_loop->min; PrimExpr new_ext = old_loop->extent - make_const(old_loop->loop_var.dtype(), 1); PrimExpr factor = make_const(new_ext.dtype(), split_loop_); @@ -191,11 +191,11 @@ class DoubleBufferInjector : public StmtExprMutator { auto it = dbuffer_info_.find(node->buffer->data.get()); if (it != dbuffer_info_.end()) { const StorageEntry& e = it->second; - ICHECK(in_double_buffer_scope_); - ICHECK(e.switch_write_var.defined()); + TVM_FFI_ICHECK(in_double_buffer_scope_); + TVM_FFI_ICHECK(e.switch_write_var.defined()); - ICHECK_EQ(node->indices.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " - << "Has FlattenBuffer been run?"; + TVM_FFI_ICHECK_EQ(node->indices.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " + << "Has FlattenBuffer been run?"; auto writer = node.CopyOnWrite(); writer->buffer = GetRemappedBuffer(node->buffer, e.stride); @@ -211,10 +211,10 @@ class DoubleBufferInjector : public StmtExprMutator { auto it = dbuffer_info_.find(node->buffer->data.get()); if (it != dbuffer_info_.end()) { const StorageEntry& e = it->second; - ICHECK(e.switch_read_var.defined()); + TVM_FFI_ICHECK(e.switch_read_var.defined()); - ICHECK_EQ(node->indices.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " - << "Has FlattenBuffer been run?"; + TVM_FFI_ICHECK_EQ(node->indices.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " + << "Has FlattenBuffer been run?"; auto writer = node.CopyOnWrite(); writer->buffer = GetRemappedBuffer(node->buffer, e.stride); @@ -231,13 +231,13 @@ class DoubleBufferInjector : public StmtExprMutator { return it->second; } - ICHECK(stride.defined()); + TVM_FFI_ICHECK(stride.defined()); // TODO(Lunderberg): Move this pass to before // FlattenBuffer. That will simplify the // implementation, to be the insertion of a new dimension for the // buffer, rather than adjusting the other indices. - ICHECK_EQ(buf->shape.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " - << "Has FlattenBuffer been run?"; + TVM_FFI_ICHECK_EQ(buf->shape.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " + << "Has FlattenBuffer been run?"; // Stride gives the distance between the two halves of the // double-buffer, not the stride of the buffer's index. @@ -248,14 +248,14 @@ class DoubleBufferInjector : public StmtExprMutator { } PrimExpr VisitExpr_(const VarNode* op) final { - ICHECK(!dbuffer_info_.count(op)); + TVM_FFI_ICHECK(!dbuffer_info_.count(op)); return ffi::GetRef(op); } private: Stmt MakeProducer(const AttrStmtNode* op) { const Var buffer = Downcast(op->node); - ICHECK_NE(loop_nest_.size(), 0U) << "Double buffer scope must be inside a loop"; + TVM_FFI_ICHECK_NE(loop_nest_.size(), 0U) << "Double buffer scope must be inside a loop"; auto it = dbuffer_info_.find(buffer.get()); if (it == dbuffer_info_.end()) { LOG(WARNING) << "Skip double buffer scope " << op->node; diff --git a/src/s_tir/transform/inject_permuted_layout.cc b/src/s_tir/transform/inject_permuted_layout.cc index 23f68f3b75b6..ee0479bb914c 100644 --- a/src/s_tir/transform/inject_permuted_layout.cc +++ b/src/s_tir/transform/inject_permuted_layout.cc @@ -61,7 +61,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitStmt_; ffi::Array PermuteIndices(PrimExpr row_idx, PrimExpr col_idx, int row_size) { - ICHECK(permute_); + TVM_FFI_ICHECK(permute_); // Index after vectorizing by 8 PrimExpr col_idx_outer = floordiv(col_idx, VECTORIZE_FACTOR), col_idx_inner = floormod(col_idx, VECTORIZE_FACTOR); @@ -81,7 +81,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { auto row_idx_sub = floormod(row_idx, 8); new_col_idx_outer = col_idx_outer ^ row_idx_sub; } else { - ICHECK(row_size % 32 == 0); + TVM_FFI_ICHECK(row_size % 32 == 0); // Use 8 * 4 permuted layout // Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read // Every row below corresponds to 16 banks @@ -113,7 +113,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { } else if (auto opt_val = annotation.try_cast()) { return *opt_val != 0; } else { - LOG(FATAL) << "Invalid permuted layout annotation: " << annotation; + TVM_FFI_THROW(InternalError) << "Invalid permuted layout annotation: " << annotation; } } @@ -145,7 +145,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { } int CheckAndGetBufferRowSize(Buffer buffer) { - CHECK(buffer->shape.size() >= 2) + TVM_FFI_ICHECK(buffer->shape.size() >= 2) << "The dimension of Buffer \"" << buffer->name << "\" with shape " << buffer->shape << " should be at least 2"; @@ -154,10 +154,10 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { auto buffer_col_size = buffer->shape[dim - 2].as()->value; if (buffer_row_size % 64 != 0) { - CHECK(buffer_row_size % 32 == 0) + TVM_FFI_ICHECK(buffer_row_size % 32 == 0) << "Permuted Layout for Buffer \"" << buffer->name << "\" with shape " << buffer->shape << " is not supported since its second dimension is not divisible by 32"; - CHECK(buffer_col_size % 2 == 0) + TVM_FFI_ICHECK(buffer_col_size % 2 == 0) << "Permuted Layout for Buffer \"" << buffer->name << "\" with shape " << buffer->shape << " is not supported since its first dimension is not divisible by 2 and second " "dimension is not divisible by 64"; @@ -221,14 +221,14 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { ffi::Optional offset = std::nullopt) { // The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and accumulate it to // smem_offset - CHECK(access_ptr->IsInstance()) + TVM_FFI_ICHECK(access_ptr->IsInstance()) << "Invalid access ptr for permuted layout: " << access_ptr; auto access_ptr_call = Downcast(access_ptr); - CHECK(access_ptr_call->op.same_as(builtin::tvm_access_ptr())) + TVM_FFI_ICHECK(access_ptr_call->op.same_as(builtin::tvm_access_ptr())) << "Invalid access ptr for permuted layout: " << access_ptr; auto buffer_map_iter = buffer_map_.find(Downcast(access_ptr_call->args[1])); - CHECK(buffer_map_iter != buffer_map_.end()) + TVM_FFI_ICHECK(buffer_map_iter != buffer_map_.end()) << "The buffer corresponding to data Var " << access_ptr_call->args[1] << " is not found"; int buffer_row_size = CheckAndGetBufferRowSize(buffer_map_iter->second); @@ -277,7 +277,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { new_call->args.Set(2, new_access_ptr); return call; } else { - LOG(FATAL) << "Invalid call node: " << call; + TVM_FFI_THROW(InternalError) << "Invalid call node: " << call; } } diff --git a/src/s_tir/transform/inject_ptx_async_copy.cc b/src/s_tir/transform/inject_ptx_async_copy.cc index c9b1e42d7fee..6e0257d248fd 100644 --- a/src/s_tir/transform/inject_ptx_async_copy.cc +++ b/src/s_tir/transform/inject_ptx_async_copy.cc @@ -41,7 +41,7 @@ class PTXAsyncCopyInjector : public StmtMutator { public: Stmt VisitStmt_(const AttrStmtNode* attr) { if (attr->attr_key == tir::attr::async_scope) { - ICHECK(in_async == false) << "Nested async scopes not supported"; + TVM_FFI_ICHECK(in_async == false) << "Nested async scopes not supported"; in_async = true; auto body = this->VisitStmt(attr->body); in_async = false; @@ -53,8 +53,8 @@ class PTXAsyncCopyInjector : public StmtMutator { Stmt InjectPTX(const BufferLoadNode* load, const BufferStoreNode* store, bool predicated = false, PrimExpr predicate_value = PrimExpr()) { if (load->buffer.scope() == "global") { - ICHECK(load->indices.size() == 1 && store->indices.size() == 1); - ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes()); + TVM_FFI_ICHECK(load->indices.size() == 1 && store->indices.size() == 1); + TVM_FFI_ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes()); const int indices_lanes = load->indices[0]->dtype.lanes(); const int bytes = indices_lanes * load->buffer->dtype.bytes(); @@ -62,15 +62,15 @@ class PTXAsyncCopyInjector : public StmtMutator { if (bytes == 4 || bytes == 8 || bytes == 16) { auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation); auto src_elem_type = GetPointerType(load->buffer->data->type_annotation); - ICHECK(dst_elem_type.has_value() && src_elem_type.has_value()) + TVM_FFI_ICHECK(dst_elem_type.has_value() && src_elem_type.has_value()) << "Both store and load buffer should have a pointer type annotation."; int index_factor = 1; if (dst_elem_type.value() != src_elem_type.value()) { // The only case where src and dst have different dtypes is when the dst shared memory // is a byte buffer generated by merging dynamic shared memory. - ICHECK(store->buffer.scope() == "shared.dyn"); - ICHECK(dst_elem_type.value() == DataType::UInt(8)); + TVM_FFI_ICHECK(store->buffer.scope() == "shared.dyn"); + TVM_FFI_ICHECK(dst_elem_type.value() == DataType::UInt(8)); // BufferStore/Load have the "pointer reinterpret" semantics according to their // "value" dtype. Their "indices" are supposed to be applied after such pointer cast, // for example: ((*float16)(byte_buffer))[buffer->indices] = fp16_value; diff --git a/src/s_tir/transform/inject_software_pipeline.cc b/src/s_tir/transform/inject_software_pipeline.cc index dc21dab7de35..9330f49c4b3b 100644 --- a/src/s_tir/transform/inject_software_pipeline.cc +++ b/src/s_tir/transform/inject_software_pipeline.cc @@ -143,7 +143,7 @@ class PipelineOpaqueAccessRewriter { private: int GetWmmaFragmentSize(const Buffer& buffer) { auto it = fragment_info_.find(buffer->data.get()); - ICHECK(it != fragment_info_.end()); + TVM_FFI_ICHECK(it != fragment_info_.end()); const FragmentInfo& info = (*it).second; return info.GetSize(); } @@ -183,7 +183,7 @@ class PipelineOpaqueAccessRewriter { if (buffer.scope() == "m16n8k8.matrixA" || buffer.scope() == "m16n8k8.matrixB") { // mma scope size will shrink by warp size // @see transform_mma_buffer_layout - ICHECK_EQ(Downcast(floormod(offset, 32))->value, 0) + TVM_FFI_ICHECK_EQ(Downcast(floormod(offset, 32))->value, 0) << "mma scope size should be multiple of warp size"; offset = floordiv(offset, 32); } @@ -440,7 +440,7 @@ class PipelineRewriter : public StmtExprMutator { * \return Whether region1 and region2 have intersections. */ bool MayConflict(Region region1, Region region2) { - ICHECK(region1.size() == region2.size()); + TVM_FFI_ICHECK(region1.size() == region2.size()); for (size_t i = 0; i < region1.size(); i++) { Range dim1 = region1[i]; Range dim2 = region2[i]; @@ -533,7 +533,7 @@ class PipelineRewriter : public StmtExprMutator { ObjectPtr new_buffer = ffi::make_object(*(buffer.get())); new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); if (new_buffer->strides.size()) { - ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); + TVM_FFI_ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1]; new_buffer->strides.insert(new_buffer->strides.begin(), stride_0); } @@ -620,7 +620,7 @@ class PipelineRewriter : public StmtExprMutator { for (auto kv : async_states) { if (kv.first <= new_blocks[i].stage && kv.second.writes(read_region->buffer)) { // Found an earlier stage where read_region->buffer was asynchronously written - ICHECK(producer_stage_idx == -1 || producer_stage_idx == kv.first) + TVM_FFI_ICHECK(producer_stage_idx == -1 || producer_stage_idx == kv.first) << "A dependency on multiple async stages is not supported"; producer_stage_idx = kv.first; } @@ -678,10 +678,10 @@ class PipelineRewriter : public StmtExprMutator { if (num_commit_group == 0) { // Epilogue, no async producer. Since "local" producer_head is not available, use // "global" producer_head. - ICHECK(!dep_local_state.producer_head); + TVM_FFI_ICHECK(!dep_local_state.producer_head); producer_head_per_commit.push_back(async_states[producer_stage_idx].producer_head); } else { - ICHECK(dep_local_state.producer_head); + TVM_FFI_ICHECK(dep_local_state.producer_head); std::vector need_wait_count(num_commit_group, true); for (auto read_region : new_blocks[i].block->reads) { @@ -741,7 +741,7 @@ class PipelineRewriter : public StmtExprMutator { if (!state.commit_groups.empty()) { for (size_t i = 0; i < state.commit_groups.size(); ++i) { for (size_t j = 0; j < state.commit_groups[i].size(); ++j) { - ICHECK(state.commit_groups[i][0] + j < new_blocks.size()); + TVM_FFI_ICHECK(state.commit_groups[i][0] + j < new_blocks.size()); commit_group_indices[state.commit_groups[i][0] + j] = stage_id; } } @@ -783,7 +783,7 @@ class PipelineRewriter : public StmtExprMutator { auto stage_id = commit_group_indices[i]; auto predicate = new_blocks[i].predicate; for (; i < commit_group_indices.size() && commit_group_indices[i] == stage_id; ++i) { - ICHECK(tvm::StructuralEqual()(predicate, new_blocks[i].predicate)) + TVM_FFI_ICHECK(tvm::StructuralEqual()(predicate, new_blocks[i].predicate)) << "Predicates in the same stage are expected to be identical"; group_bodies.push_back(new_blocks[i].block->body); } @@ -1050,8 +1050,8 @@ class PipelineInjector : private StmtExprMutator { for (const SBlock& block : original_order) { const auto& stmt_info = pipeline_info.at(block); int order = stmt_info.order; - CHECK(!used_orders.count(order)) - << "ValueError: Two statements in the software pipeline cannot have the same order"; + TVM_FFI_CHECK(!used_orders.count(order), ValueError) + << "Two statements in the software pipeline cannot have the same order"; used_orders.insert(order); } @@ -1064,13 +1064,14 @@ class PipelineInjector : private StmtExprMutator { const ffi::Array& dsts = pair.second; for (const SBlock& dst : dsts) { const auto& dst_info = pipeline_info.at(dst); - CHECK_LE(src_info.stage, dst_info.stage) - << "ValueError: statement " << dst << " in stage " << dst_info.stage + TVM_FFI_CHECK_LE(src_info.stage, dst_info.stage, ValueError) + << "statement " << dst << " in stage " << dst_info.stage << " cannot depends on statement " << src << " in a later stage " << src_info.stage; if (src_info.stage == dst_info.stage) { - CHECK_LT(src_info.order, dst_info.order) << "ValueError: two statements with buffer " - "access dependency in the same stage of the " - "software pipeline cannot be reordered"; + TVM_FFI_CHECK_LT(src_info.order, dst_info.order, ValueError) + << "two statements with buffer " + "access dependency in the same stage of the " + "software pipeline cannot be reordered"; } } } @@ -1090,7 +1091,7 @@ class PipelineInjector : private StmtExprMutator { if (const auto* realize = for_node->body.as()) { const auto& block = realize->block; for (const auto& buffer : block->alloc_buffers) { - ICHECK(buffer->IsInstance()); + TVM_FFI_ICHECK(buffer->IsInstance()); buffer_data_to_buffer_.Set(buffer->data, buffer); } pipeline_body = block->body; @@ -1100,8 +1101,8 @@ class PipelineInjector : private StmtExprMutator { } const SeqStmtNode* pipeline_body_seq = pipeline_body.as(); - CHECK(pipeline_body_seq) - << "ValueError: The body of the software pipeline should be SeqStmt, got " + TVM_FFI_CHECK(pipeline_body_seq, ValueError) + << "The body of the software pipeline should be SeqStmt, got " << pipeline_body->GetTypeKey(); // Step 3: Blockize the components of the pipeline. Each child of the pipelined loop will be @@ -1117,7 +1118,7 @@ class PipelineInjector : private StmtExprMutator { if (nested_block_realize && is_one(nested_block_realize->predicate) && nested_block_realize->block->body->IsInstance()) { const SBlock& nested_pipeline_block = nested_block_realize->block; - ICHECK( + TVM_FFI_ICHECK( nested_pipeline_block->match_buffers.empty()); // match_buffer should have been lowered for (const auto& buffer : nested_pipeline_block->alloc_buffers) { pipeline_allocs.push_back(buffer); @@ -1136,11 +1137,11 @@ class PipelineInjector : private StmtExprMutator { Downcast>(op->annotations.at(tir::attr::software_pipeline_stage)); auto pipeline_orders = Downcast>(op->annotations.at(tir::attr::software_pipeline_order)); - CHECK_EQ(pipeline_stages.size(), original_order.size()) + TVM_FFI_ICHECK_EQ(pipeline_stages.size(), original_order.size()) << "PrimFunc " << global_symbol_ << " has original order " << original_order.Map([](const auto& block) { return block->name_hint; }) << ", but pipeline annotation is " << pipeline_stages << " with different size"; - CHECK_EQ(pipeline_orders.size(), original_order.size()) + TVM_FFI_ICHECK_EQ(pipeline_orders.size(), original_order.size()) << "PrimFunc " << global_symbol_ << " has original order " << original_order.Map([](const auto& block) { return block->name_hint; }) << ", but pipeline annotation is " << pipeline_orders << " with different size"; @@ -1212,8 +1213,9 @@ class PipelineInjector : private StmtExprMutator { auto it = op->annotations.find(tir::attr::double_buffer_scope); if (it != op->annotations.end()) { int buffer_index = Downcast((*it).second).IntValue(); - CHECK(buffer_index >= 0 && static_cast(buffer_index) < op->writes.size()) - << "ValueError: Index of the buffer exceeds the size of the write regions of the block. (" + TVM_FFI_CHECK(buffer_index >= 0 && static_cast(buffer_index) < op->writes.size(), + ValueError) + << "Index of the buffer exceeds the size of the write regions of the block. (" << buffer_index << " vs. " << op->writes.size() << ")"; double_buffers.insert(op->writes[buffer_index]->buffer); } @@ -1234,10 +1236,10 @@ class PipelineInjector : private StmtExprMutator { return true; } if (has_stage) { - LOG(FATAL) << "ValueError: Order of the software pipeline is not defined."; + TVM_FFI_THROW(ValueError) << "Order of the software pipeline is not defined."; } if (has_order) { - LOG(FATAL) << "ValueError: Stage of the software pipeline is not defined."; + TVM_FFI_THROW(ValueError) << "Stage of the software pipeline is not defined."; } return false; } diff --git a/src/s_tir/transform/inject_virtual_thread.cc b/src/s_tir/transform/inject_virtual_thread.cc index 3ff7eeb7b985..b644dffe8d6f 100644 --- a/src/s_tir/transform/inject_virtual_thread.cc +++ b/src/s_tir/transform/inject_virtual_thread.cc @@ -61,8 +61,8 @@ class ExprTouched final : public StmtExprVisitor { if (op->op.same_as(builtin::tvm_access_ptr())) { const auto* rw_mask = op->args[4].as(); const VarNode* buffer_var = op->args[1].as(); - ICHECK(buffer_var); - ICHECK(rw_mask); + TVM_FFI_ICHECK(buffer_var); + TVM_FFI_ICHECK(rw_mask); // read if (rw_mask->value & 1) { HandleUseVar(buffer_var); @@ -192,7 +192,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { allow_share_(allow_share) {} // Inject VTLoop when needed. Stmt VisitStmt(const Stmt& s) final { - ICHECK(!visit_touched_var_); + TVM_FFI_ICHECK(!visit_touched_var_); auto stmt = StmtExprMutator::VisitStmt(s); if (visit_touched_var_ || trigger_base_inject_) { if (!vt_loop_injected_) { @@ -205,7 +205,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { } // Variable PrimExpr VisitExpr_(const VarNode* op) final { - ICHECK(!alloc_remap_.count(op)) << "Buffer address may get rewritten in virtual thread"; + TVM_FFI_ICHECK(!alloc_remap_.count(op)) << "Buffer address may get rewritten in virtual thread"; if (touched_var_.count(op)) { visit_touched_var_ = true; } @@ -217,7 +217,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { // Expression. PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_access_ptr())) { - ICHECK_EQ(op->args.size(), 5U); + TVM_FFI_ICHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); auto it = alloc_remap_.find(buffer); @@ -259,7 +259,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { auto it = alloc_remap_.find(node->buffer->data.get()); if (it != alloc_remap_.end()) { - ICHECK_EQ(node->indices.size(), 1) + TVM_FFI_ICHECK_EQ(node->indices.size(), 1) << "InjectVirtualThread expects rewritten allocations to be flat memory."; auto writer = node.CopyOnWrite(); writer->buffer = GetRemappedBuffer(node->buffer, it->second); @@ -276,7 +276,8 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { return it->second; } - ICHECK_EQ(buf->shape.size(), 1) << "Expected buffers being rewritten to already be flattened."; + TVM_FFI_ICHECK_EQ(buf->shape.size(), 1) + << "Expected buffers being rewritten to already be flattened."; auto writer = buf.CopyOnWrite(); writer->shape = {buf->shape[0] * alloc_extent}; @@ -318,7 +319,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { } // For Stmt VisitStmt_(const ForNode* op) final { - ICHECK(is_zero(op->min)); + TVM_FFI_ICHECK(is_zero(op->min)); PrimExpr extent = this->VisitExpr(op->extent); if (visit_touched_var_ && !vt_loop_injected_) { Stmt stmt = InjectVTLoop(ffi::GetRef(op), true); @@ -344,7 +345,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { return InjectVTLoop(ffi::GetRef(op), true); } visit_touched_var_ = false; - ICHECK_EQ(max_loop_depth_, 0); + TVM_FFI_ICHECK_EQ(max_loop_depth_, 0); Stmt then_case = this->VisitStmt(op->then_case); ffi::Optional else_case = std::nullopt; if (op->else_case) { @@ -364,12 +365,12 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { // While Stmt VisitStmt_(const WhileNode* op) final { // TODO(masahi): What should we do for While nodes? - LOG(FATAL) << "WhileNode in InjectVirtualThread not supported yet"; + TVM_FFI_THROW(InternalError) << "WhileNode in InjectVirtualThread not supported yet"; } // Seq Stmt VisitStmt_(const SeqStmtNode* op) final { - ICHECK_EQ(max_loop_depth_, 0); + TVM_FFI_ICHECK_EQ(max_loop_depth_, 0); auto fmutate = [this](const Stmt& s) { int temp = max_loop_depth_; max_loop_depth_ = 0; @@ -404,7 +405,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { // TODO(Lunderberg): Move pass to apply before // FlattenBuffer. Would rewrite the Buffer to // add the injected virtual thread as the first index. - ICHECK_EQ(extents.size(), 1) + TVM_FFI_ICHECK_EQ(extents.size(), 1) << "InjectVirtualThread expects rewritten allocations to be flat memory."; PrimExpr stride = extents[0]; extents = {stride * num_threads_}; @@ -427,7 +428,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { // inject vthread loop Stmt InjectVTLoop(Stmt stmt, bool before_mutation) { - ICHECK(!vt_loop_injected_); + TVM_FFI_ICHECK(!vt_loop_injected_); // reset the flags visit_touched_var_ = false; trigger_base_inject_ = false; diff --git a/src/s_tir/transform/loop_partition.cc b/src/s_tir/transform/loop_partition.cc index 9d2809103f51..8020e97867ce 100644 --- a/src/s_tir/transform/loop_partition.cc +++ b/src/s_tir/transform/loop_partition.cc @@ -134,7 +134,7 @@ class CandidateSelector final : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == tir::attr::thread_extent) { const IterVarNode* iv = op->node.as(); - ICHECK(iv); + TVM_FFI_ICHECK(iv); Var var = iv->var; // always treat var with hint to be partitioned if (partition_hint_vars.count(var.get())) { @@ -160,7 +160,7 @@ class CandidateSelector final : public StmtExprVisitor { } else if (op->node.as()) { var = op->node.as()->var.get(); } - ICHECK(var); + TVM_FFI_ICHECK(var); partition_hint_vars.insert(var); } } @@ -256,7 +256,7 @@ class PartitionFinder : public StmtExprVisitor { // handle thread_axis if (op->attr_key == tir::attr::thread_extent) { const IterVarNode* thread_axis = op->node.as(); - ICHECK(thread_axis); + TVM_FFI_ICHECK(thread_axis); const VarNode* var = thread_axis->var.get(); IntSet dom = IntSet::FromRange(Range(make_zero(op->value.dtype()), op->value)); hint_map_.insert({var, dom}); @@ -442,7 +442,7 @@ class LoopPartitioner : public StmtMutator { } const IterVarNode* iv = op->node.as(); - ICHECK(iv); + TVM_FFI_ICHECK(iv); Var var = iv->var; auto as = ffi::GetRef(op); if (selector.candidates.count(as)) { @@ -760,14 +760,14 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt body) { const ForNode* for_node = static_cast(node); - ICHECK(for_node); + TVM_FFI_ICHECK(for_node); if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1)) && !no_unroll_loop_with_extent_one_ && for_node->annotations.empty()) { // If the loop extent is 1, do not create the loop anymore return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}}); } else { - ICHECK(for_node->kind != ForKind::kThreadBinding); + TVM_FFI_ICHECK(for_node->kind != ForKind::kThreadBinding); auto new_loop = ffi::make_object(*for_node); new_loop->min = IntImm(for_node->min.dtype(), 0); new_loop->extent = extent; @@ -780,10 +780,10 @@ class RemoveLikelyTagsAndHints : public StmtExprMutator { public: PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::likely())) { - ICHECK_EQ(op->args.size(), 1); + TVM_FFI_ICHECK_EQ(op->args.size(), 1); return StmtExprMutator::VisitExpr(op->args[0]); } else if (op->op.same_as(builtin::ignore_loop_partition())) { - ICHECK_EQ(op->args.size(), 1); + TVM_FFI_ICHECK_EQ(op->args.size(), 1); return StmtExprMutator::VisitExpr(op->args[0]); } else { return StmtExprMutator::VisitExpr_(op); diff --git a/src/s_tir/transform/lower_async_dma.cc b/src/s_tir/transform/lower_async_dma.cc index 8d95c7691890..d00824eed467 100644 --- a/src/s_tir/transform/lower_async_dma.cc +++ b/src/s_tir/transform/lower_async_dma.cc @@ -96,7 +96,7 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { if (op->attr_key == tir::attr::async_wait_queue_scope) { // get queue ID auto queue_id_node = op->value.as(); - ICHECK(queue_id_node); + TVM_FFI_ICHECK(queue_id_node); int queue_id = queue_id_node->value; // abort if we have not seen this queue ID in `copy` transform @@ -138,7 +138,7 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { } else if (op->attr_key == tir::attr::async_commit_queue_scope) { // get queue ID auto queue_id_node = op->value.as(); - ICHECK(queue_id_node); + TVM_FFI_ICHECK(queue_id_node); async_queue_id_ = queue_id_node->value; auto result = arith::IRMutatorWithAnalyzer::VisitStmt_(op); if (dmas_in_group_ > 1) { diff --git a/src/s_tir/transform/lower_cross_thread_reduction.cc b/src/s_tir/transform/lower_cross_thread_reduction.cc index f442841256f3..c3c1f2ab3a0a 100644 --- a/src/s_tir/transform/lower_cross_thread_reduction.cc +++ b/src/s_tir/transform/lower_cross_thread_reduction.cc @@ -87,7 +87,7 @@ bool IsDominantBlock(const SBlock& scope_block, const SBlock& block) { }); // Step 2. Check whether `block` is the only writer of its outputs. for (const BufferRegion& buffer_region : block->writes) { - ICHECK(buffer_writer_cnt.count(buffer_region->buffer.get())); + TVM_FFI_ICHECK(buffer_writer_cnt.count(buffer_region->buffer.get())); if (buffer_writer_cnt[buffer_region->buffer.get()] != 1) { return false; } @@ -166,7 +166,7 @@ class BufferReplacer : private StmtExprMutator { public: static Stmt Run(ffi::Array src_buffers, ffi::Array tgt_buffers, Stmt stmt) { ffi::Map buffer_map; - ICHECK_EQ(src_buffers.size(), tgt_buffers.size()); + TVM_FFI_ICHECK_EQ(src_buffers.size(), tgt_buffers.size()); int n_buffers = src_buffers.size(); for (int i = 0; i < n_buffers; ++i) { buffer_map.Set(src_buffers[i], tgt_buffers[i]); @@ -427,7 +427,7 @@ Stmt TransformReductionBlock(const SBlockRealizeNode* realize, } // Stmt 4: write cross-thread reduction result to the original buffer { - ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size()); + TVM_FFI_ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size()); int n_iter = static_cast(block->iter_vars.size()); ffi::Array iter_vars; ffi::Array bindings; @@ -649,8 +649,8 @@ class CrossThreadReductionTransformer : public StmtMutator { } ++n_deepest_reduction_loops; } - CHECK_EQ(n_deepest_reduction_loops, reduction_loops.size()) - << "ValueError: Cross-thread reduction requires all the reduction-related loops to be the " + TVM_FFI_CHECK_EQ(n_deepest_reduction_loops, reduction_loops.size(), ValueError) + << "Cross-thread reduction requires all the reduction-related loops to be the " "deepest among all statements outside the desired block. However, block " << block->name_hint << " needs cross-thread reduction, while the reduction-related loops outside of it are not " @@ -662,8 +662,8 @@ class CrossThreadReductionTransformer : public StmtMutator { for (const ForNode* reduction_loop : reduction_loops) { if (reduction_loop->thread_binding.defined()) { ++n_bound_reduction_loops; - CHECK(IsBoundToThreadIdx(reduction_loop)) - << "ValueError: Cross-thread reduction requires all the reduction-related loops that " + TVM_FFI_CHECK(IsBoundToThreadIdx(reduction_loop), ValueError) + << "Cross-thread reduction requires all the reduction-related loops that " "are bound to GPU thread axes to only be bound `threadIdx.x/y/z`. However, loop " << reduction_loop->loop_var->name_hint << " violates the condition."; } @@ -689,14 +689,14 @@ class CrossThreadReductionTransformer : public StmtMutator { for (const BufferStore& buf_store : updates) { reduction_buffers.push_back(buf_store->buffer); if (buf_store->buffer.scope() == "local") { - CHECK_NE(is_local_buf, 0) - << "ValueError: Cross-thread reduction requires all reduction buffers to be all " + TVM_FFI_CHECK_NE(is_local_buf, 0, ValueError) + << "Cross-thread reduction requires all reduction buffers to be all " "local or all non-local. However, here some buffer is local while some buffer is " "shared or global."; is_local_buf = 1; } else { - CHECK_NE(is_local_buf, 1) - << "ValueError: Cross-thread reduction requires all reduction buffers to be all " + TVM_FFI_CHECK_NE(is_local_buf, 1, ValueError) + << "Cross-thread reduction requires all reduction buffers to be all " "local or all non-local. However, here some buffer is local while some buffer is " "shared or global."; is_local_buf = 0; @@ -707,8 +707,9 @@ class CrossThreadReductionTransformer : public StmtMutator { bool visit = false; PreOrderVisit(ffi::GetRef(reduction_loops[0]), [block, &visit](const ObjectRef& obj) { if (const auto* realize = obj.as()) { - CHECK(!visit) << "ValueError: Cross-thread reduction cannot be applied when the reduction " - "block isn't the last block under its first reduction-related loop"; + TVM_FFI_CHECK(!visit, ValueError) + << "Cross-thread reduction cannot be applied when the reduction " + "block isn't the last block under its first reduction-related loop"; if (realize->block.get() == block) { visit = true; } diff --git a/src/s_tir/transform/lower_match_buffer.cc b/src/s_tir/transform/lower_match_buffer.cc index 5c324eecb7e4..7c0b2e81bc5b 100644 --- a/src/s_tir/transform/lower_match_buffer.cc +++ b/src/s_tir/transform/lower_match_buffer.cc @@ -52,7 +52,7 @@ class MatchBufferLower : public StmtExprMutator { Stmt stmt = StmtExprMutator ::VisitStmt_(op); op = stmt.as(); - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); ffi::Array reads = op->reads.Map(std::bind(&MatchBufferLower::VisitBufferRegion, this, std::placeholders::_1)); ffi::Array writes = op->writes.Map( @@ -87,7 +87,7 @@ class MatchBufferLower : public StmtExprMutator { Stmt VisitStmt_(const BufferStoreNode* op) final { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); auto it = match_buffers_.find(op->buffer); if (it == match_buffers_.end()) { @@ -99,7 +99,7 @@ class MatchBufferLower : public StmtExprMutator { auto n = CopyOnWrite(op); n->indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); n->buffer = source->buffer; - ICHECK(!op->predicate.defined()) + TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in lower match buffer pass."; return Stmt(n); } @@ -108,7 +108,7 @@ class MatchBufferLower : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); auto it = match_buffers_.find(op->buffer); if (it == match_buffers_.end()) { @@ -117,7 +117,7 @@ class MatchBufferLower : public StmtExprMutator { const Buffer& buffer = (*it).first; const BufferRegion& source = (*it).second; ffi::Array indices = ConvertIndices(MatchBufferRegion(buffer, source), op->indices); - ICHECK(!op->predicate.defined()) + TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in lower match buffer pass."; return BufferLoad(source->buffer, indices); } @@ -143,10 +143,10 @@ class MatchBufferLower : public StmtExprMutator { const Buffer& source_buffer = source->buffer; // Step.1.1. Check scope & dtype - ICHECK_EQ(buffer.scope(), source_buffer.scope()) + TVM_FFI_ICHECK_EQ(buffer.scope(), source_buffer.scope()) << "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << "vs." << source_buffer.scope(); - ICHECK_EQ(buffer->dtype, source_buffer->dtype) + TVM_FFI_ICHECK_EQ(buffer->dtype, source_buffer->dtype) << "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << "vs." << source_buffer->dtype; @@ -157,7 +157,7 @@ class MatchBufferLower : public StmtExprMutator { << ", provided alignment=" << source_buffer->data_alignment; } if (is_zero(buffer->elem_offset)) { - ICHECK(is_zero(source_buffer->elem_offset)) + TVM_FFI_ICHECK(is_zero(source_buffer->elem_offset)) << "Trying to bind a Buffer with offset into one without offset " << " required elem_offset=" << buffer->elem_offset << ", provided elem_offset=" << source_buffer->elem_offset; @@ -180,7 +180,8 @@ class MatchBufferLower : public StmtExprMutator { ffi::Array buffer_start_indices = source_buffer->ElemOffset(indices); if (buffer_start_indices.size() == 1) { Bind(buffer->elem_offset, buffer_start_indices[0], buffer->name + ".elem_offset"); - CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) + TVM_FFI_ICHECK( + analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) << "The source elem_offset " << buffer_start_indices[0] << " does not satisfy the offset_factor " << buffer->offset_factor << "."; } else { @@ -193,10 +194,10 @@ class MatchBufferLower : public StmtExprMutator { // Step 2.3. Check and update strides // Check if target buffer strides are defined - ICHECK(source->region.size() >= buffer->shape.size()); + TVM_FFI_ICHECK(source->region.size() >= buffer->shape.size()); int offset = source->region.size() - buffer->shape.size(); if (!buffer->strides.empty()) { - ICHECK_EQ(buffer->strides.size(), buffer->shape.size()); + TVM_FFI_ICHECK_EQ(buffer->strides.size(), buffer->shape.size()); if (source_buffer->strides.empty()) { PrimExpr stride = make_const(buffer->strides.back().dtype(), 1); for (size_t i = buffer->shape.size(); i > 0; --i) { @@ -205,7 +206,7 @@ class MatchBufferLower : public StmtExprMutator { stride *= shape; } } else { - ICHECK_EQ(buffer->shape.size() + offset, source_buffer->strides.size()); + TVM_FFI_ICHECK_EQ(buffer->shape.size() + offset, source_buffer->strides.size()); for (size_t i = buffer->shape.size(); i > 0; --i) { const PrimExpr& stride = source_buffer->strides[i - 1 + offset]; Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1)); @@ -226,7 +227,7 @@ class MatchBufferLower : public StmtExprMutator { arg.dtype().lanes() == value.dtype().lanes()) { value = cast(arg.dtype(), value); } else { - CHECK_EQ(arg.dtype(), value.dtype()) + TVM_FFI_ICHECK_EQ(arg.dtype(), value.dtype()) << "The data type mismatched: " << arg->dtype << " vs. " << value->dtype; } } @@ -248,8 +249,8 @@ class MatchBufferLower : public StmtExprMutator { void AssertBinding(const PrimExpr& lhs, const PrimExpr& rhs, const std::string& arg_name = "argument") { - CHECK(analyzer_.CanProve(lhs == rhs)) << "The buffer match constraint for " << arg_name - << " unmet: " << lhs << "==" << rhs << "."; + TVM_FFI_ICHECK(analyzer_.CanProve(lhs == rhs)) << "The buffer match constraint for " << arg_name + << " unmet: " << lhs << "==" << rhs << "."; } private: diff --git a/src/s_tir/transform/lower_opaque_block.cc b/src/s_tir/transform/lower_opaque_block.cc index cc3ad765ebc6..f06d70bc5336 100644 --- a/src/s_tir/transform/lower_opaque_block.cc +++ b/src/s_tir/transform/lower_opaque_block.cc @@ -45,8 +45,9 @@ class OpaqueBlockLower : public StmtExprMutator { private: Stmt VisitStmt_(const SBlockRealizeNode* op) final { // We have convert blocks into opaque blocks in previous passes. - ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in FlattenBuffer. Please " - "call pass ConvertBlocksToOpaque before."; + TVM_FFI_ICHECK(op->iter_values.empty()) + << "Non-opaque blocks are not allowed in FlattenBuffer. Please " + "call pass ConvertBlocksToOpaque before."; // Step 1. Visit the body SBlock new_block = Downcast(this->VisitStmt(op->block)); PrimExpr predicate = this->VisitExpr(op->predicate); @@ -102,7 +103,7 @@ class OpaqueBlockLower : public StmtExprMutator { // Step 4. Create new For loop accordingly if (op->kind == ForKind::kThreadBinding) { // Case 1. Thread binding - ICHECK(op->thread_binding.defined()); + TVM_FFI_ICHECK(op->thread_binding.defined()); ffi::String thread_tag = op->thread_binding.value()->thread_tag; body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body); } else if (is_one(extent) && op->annotations.empty() && @@ -161,8 +162,8 @@ class OpaqueBlockLower : public StmtExprMutator { } else if (auto str = obj.try_cast()) { return std::move(StringImm(str.value())); } else { - LOG(FATAL) << "Illegal attribute of key " << key << ", value type " << obj.GetTypeKey() - << " not supported"; + TVM_FFI_THROW(InternalError) << "Illegal attribute of key " << key << ", value type " + << obj.GetTypeKey() << " not supported"; return PrimExpr(); } } diff --git a/src/s_tir/transform/lower_thread_allreduce.cc b/src/s_tir/transform/lower_thread_allreduce.cc index 5f1fc9afaa40..39ccef472139 100644 --- a/src/s_tir/transform/lower_thread_allreduce.cc +++ b/src/s_tir/transform/lower_thread_allreduce.cc @@ -55,7 +55,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { return ret; } else if (op->attr_key == tir::attr::reduce_scope) { const CommReducerNode* combiner = op->node.as(); - ICHECK(combiner); + TVM_FFI_ICHECK(combiner); reduce_combiner_.push_back(combiner); Stmt ret = StmtExprMutator::VisitStmt_(op); reduce_combiner_.pop_back(); @@ -119,7 +119,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* op) final { if (auto it = load_remap_.find(op->buffer->data.get()); it != load_remap_.end()) { for (const auto& index : op->indices) { - ICHECK(is_zero(index)); + TVM_FFI_ICHECK(is_zero(index)); } return it->second; } @@ -156,13 +156,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // make allreduce. Stmt MakeAllreduce(const CallNode* call) { - ICHECK(!reduce_combiner_.empty()); + TVM_FFI_ICHECK(!reduce_combiner_.empty()); const CommReducerNode* combiner = reduce_combiner_.back(); size_t size = combiner->result.size(); const IntImmNode* size_of_args = call->args[0].as(); - ICHECK(size_of_args) << call->args[0]->GetTypeKey(); - ICHECK_EQ(size, size_of_args->value); + TVM_FFI_ICHECK(size_of_args) << call->args[0]->GetTypeKey(); + TVM_FFI_ICHECK_EQ(size, size_of_args->value); ffi::Array inits = combiner->identity_element; std::vector values(size); std::vector types(size); @@ -194,7 +194,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (v) { reduce_set.insert(v); } else { - ICHECK(call->args[i].as() && call->args[i].as()->value == 0) + TVM_FFI_ICHECK(call->args[i].as() && call->args[i].as()->value == 0) << "arg" << i << "should be a VarNode or IntImmNode"; } } @@ -206,11 +206,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { IterVar iv = Downcast(attr->node); e.scope = runtime::ThreadScope::Create(iv->thread_tag); e.iv = iv; - ICHECK_LE(e.scope.rank, 1); - ICHECK_GE(e.scope.dim_index, 0) << "vthread do not work with cross thread reduction"; + TVM_FFI_ICHECK_LE(e.scope.rank, 1); + TVM_FFI_ICHECK_GE(e.scope.dim_index, 0) << "vthread do not work with cross thread reduction"; if (e.scope.rank == 1) { const auto* ptr = attr->value.as(); - ICHECK(ptr) << "Need constant extent for reduce set " << iv; + TVM_FFI_ICHECK(ptr) << "Need constant extent for reduce set " << iv; e.extent = static_cast(ptr->value); // ignore variables equal to 0 if (e.extent == 1) { @@ -225,7 +225,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } } } - ICHECK_EQ(nmatch, reduce_set.size()) << "Not all reduce index are presented in the context"; + TVM_FFI_ICHECK_EQ(nmatch, reduce_set.size()) + << "Not all reduce index are presented in the context"; std::sort(vred.begin(), vred.end()); std::sort(vpar.begin(), vpar.end()); // the size of each index. @@ -306,7 +307,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t i = 0; i < size; ++i) { Buffer buf = Downcast(reduce_results[i])->buffer; PrimExpr val = BufferLoad(buf, {zero_index}); - ICHECK_EQ(val->dtype, types[i]); + TVM_FFI_ICHECK_EQ(val->dtype, types[i]); PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), new_alloc_bufs.back(), val, reduce_extent * group_index); seq.push_back(BufferStore(buf, splat, {zero_index})); @@ -377,10 +378,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // Write back allreduce results and update existing allocations. for (size_t i = 0; i < size; ++i) { - ICHECK(!load_remap_.count(buffers[i]->data.get())); + TVM_FFI_ICHECK(!load_remap_.count(buffers[i]->data.get())); PrimExpr pred = const_true(types[i].lanes()); Buffer buf = Downcast(reduce_results[i])->buffer; - ICHECK_EQ(reduce_results[i]->dtype, types[i]); + TVM_FFI_ICHECK_EQ(reduce_results[i]->dtype, types[i]); load_remap_[buffers[i]->data.get()] = reduce_results[i]; auto node = Allocate(buf->data, types[i], buf->shape, pred, Evaluate(0)); @@ -411,11 +412,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index, reduce_extent, group_extent, contiguous_reduce_extent)); for (size_t idx = 0; idx < size; ++idx) { - ICHECK(!load_remap_.count(buffers[idx]->data.get())); + TVM_FFI_ICHECK(!load_remap_.count(buffers[idx]->data.get())); PrimExpr pred = const_true(types[idx].lanes()); BufferLoad load(shared_bufs[idx], {BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent)}); - ICHECK_EQ(load->dtype, types[idx]); + TVM_FFI_ICHECK_EQ(load->dtype, types[idx]); load_remap_[buffers[idx]->data.get()] = load; alloc_remap_[buffers[idx]->data.get()] = shared_bufs[idx]; var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data; @@ -494,7 +495,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (int i = 0; i < n_buffers; ++i) { Buffer shared_buf = shared_bufs[i]; BufferLoad val(shared_buf, zero_indices); - ICHECK_EQ(val->dtype, dtypes[i]); + TVM_FFI_ICHECK_EQ(val->dtype, dtypes[i]); a.push_back(val); // __shfl_*sync calls shall not appear in if_then_else expressions @@ -515,7 +516,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { seq->push_back(s); BufferLoad load = BufferLoad(local_buf, zero_indices); - ICHECK_EQ(load->dtype, dtypes[i]); + TVM_FFI_ICHECK_EQ(load->dtype, dtypes[i]); b.push_back(load); } @@ -563,7 +564,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { while (reduce_extent > reduce_align) { reduce_align = reduce_align << 1; } - ICHECK_GT(reduce_align, 1); + TVM_FFI_ICHECK_GT(reduce_align, 1); std::vector seq; size_t size = shared_bufs.size(); @@ -574,11 +575,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t i = 0; i < size; ++i) { BufferLoad b_load(shared_bufs[i], {BufIndex(reduce_index + offset, group_index, reduce_extent)}); - ICHECK_EQ(b_load->dtype, types[i]); + TVM_FFI_ICHECK_EQ(b_load->dtype, types[i]); b.push_back(b_load); BufferLoad a_load(shared_bufs[i], {buf_index}); - ICHECK_EQ(a_load->dtype, types[i]); + TVM_FFI_ICHECK_EQ(a_load->dtype, types[i]); a.push_back(a_load); } ffi::Array ret = (*combiner)(a, b); @@ -676,7 +677,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (ret.defined()) { ret = ret + e.iv->var * total_extent; } else { - ICHECK_EQ(total_extent, 1); + TVM_FFI_ICHECK_EQ(total_extent, 1); ret = e.iv->var; } total_extent *= e.extent; @@ -802,7 +803,7 @@ Pass LowerThreadAllreduce() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; + TVM_FFI_ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; const TargetNode* target_node = target.as(); ThreadAllreduceBuilder thread_all_reduce(target_node); n->body = thread_all_reduce(n->body); diff --git a/src/s_tir/transform/lower_vtcm_alloc.cc b/src/s_tir/transform/lower_vtcm_alloc.cc index 469f7c465525..a6e05f9ae28b 100644 --- a/src/s_tir/transform/lower_vtcm_alloc.cc +++ b/src/s_tir/transform/lower_vtcm_alloc.cc @@ -54,7 +54,7 @@ class VtcmAllocator : public StmtExprMutator { protected: std::string GetStorageScope(const Var& var) { auto* ptr = var->type_annotation.as(); - ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; + TVM_FFI_ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; return ptr->storage_scope; } }; diff --git a/src/s_tir/transform/manifest_shared_memory_local_stage.cc b/src/s_tir/transform/manifest_shared_memory_local_stage.cc index 0d409d34d058..6b4089bd299f 100644 --- a/src/s_tir/transform/manifest_shared_memory_local_stage.cc +++ b/src/s_tir/transform/manifest_shared_memory_local_stage.cc @@ -53,9 +53,10 @@ class IntermediateStageRewriter { std::tuple Rewrite(const SBlockNode* block) { const BufferStoreNode* store = block->body.as(); - CHECK(store != nullptr && runtime::StorageScope::Create(store->buffer.scope()).rank == - runtime::StorageRank::kShared) - << "ValueError: Expect the body of the block to be BufferStore to shared memory."; + TVM_FFI_CHECK(store != nullptr && runtime::StorageScope::Create(store->buffer.scope()).rank == + runtime::StorageRank::kShared, + ValueError) + << "Expect the body of the block to be BufferStore to shared memory."; const Buffer& target_buffer = store->buffer; @@ -69,8 +70,9 @@ class IntermediateStageRewriter { Stmt local_stage = MakeLocalStage(block, new_buffer, buffer_indices, relaxed_loops, store); // Step 3: Create BufferLoad from the intermediate buffer - ICHECK(!store->predicate.defined()) << "Predicated buffer store is not currently supported in " - "manifest shared memory local stage pass."; + TVM_FFI_ICHECK(!store->predicate.defined()) + << "Predicated buffer store is not currently supported in " + "manifest shared memory local stage pass."; BufferLoad new_buffer_load = BufferLoad(new_buffer, buffer_indices); BufferStore new_buffer_store = Downcast(block->body); new_buffer_store.CopyOnWrite()->value = new_buffer_load; @@ -88,30 +90,30 @@ class IntermediateStageRewriter { for (int n = static_cast(ancestor_loop_or_blocks_.size()) - 1, i = n - 1; i >= 0; --i) { const Stmt& ancestor = ancestor_loop_or_blocks_[i]; if (const ForNode* ancestor_loop = ancestor.as()) { - CHECK(ancestor_loop->kind == ForKind::kSerial || - ancestor_loop->kind == ForKind::kVectorized) - << "ValueError: Expect the ancestor loops to be serial or vectorized, got " - << ancestor_loop->kind; + TVM_FFI_CHECK( + ancestor_loop->kind == ForKind::kSerial || ancestor_loop->kind == ForKind::kVectorized, + ValueError) + << "Expect the ancestor loops to be serial or vectorized, got " << ancestor_loop->kind; relaxed_loops.push_back(ancestor.as()); if (i < n - 1) { - CHECK(ancestor_loop->body.same_as(ancestor_loop_or_blocks_[i + 1])) - << "ValueError: Expect the ancestor loops to have a single child."; + TVM_FFI_CHECK(ancestor_loop->body.same_as(ancestor_loop_or_blocks_[i + 1]), ValueError) + << "Expect the ancestor loops to have a single child."; } else { const SBlockRealizeNode* block_realize = ancestor_loop->body.as(); - ICHECK(block_realize != nullptr); - CHECK(block_realize != nullptr && block_realize->block.get() == block) - << "ValueError: Expect the ancestor loops to have a single child."; + TVM_FFI_ICHECK(block_realize != nullptr); + TVM_FFI_CHECK(block_realize != nullptr && block_realize->block.get() == block, ValueError) + << "Expect the ancestor loops to have a single child."; } } else { const SBlockRealizeNode* ancestor_block_realize = ancestor.as(); - ICHECK(ancestor_block_realize != nullptr); + TVM_FFI_ICHECK(ancestor_block_realize != nullptr); const SBlockNode* ancestor_block = ancestor_block_realize->block.get(); auto it = std::find_if( ancestor_block->alloc_buffers.begin(), ancestor_block->alloc_buffers.end(), [&target_buffer](const Buffer& buffer) { return buffer.same_as(target_buffer); }); - CHECK(it != ancestor_block->alloc_buffers.end()) - << "ValueError: Expect the shared memory allocation to be in the parent block."; + TVM_FFI_CHECK(it != ancestor_block->alloc_buffers.end(), ValueError) + << "Expect the shared memory allocation to be in the parent block."; break; } } diff --git a/src/s_tir/transform/memhammer_coalesce.cc b/src/s_tir/transform/memhammer_coalesce.cc index a575fd3e9626..44a925fda77b 100644 --- a/src/s_tir/transform/memhammer_coalesce.cc +++ b/src/s_tir/transform/memhammer_coalesce.cc @@ -162,8 +162,8 @@ ffi::Array GetMapping(const Stmt& stmt, const ConstraintSet& constrain const BufferStoreNode* buf_store = TVM_TYPE_AS(body, BufferStoreNode); BufferRegion write_region = constraints.write_region; const ffi::Array& write_index = buf_store->indices; - ICHECK(write_region->region.size() == write_index.size() && - write_region->buffer.same_as(buf_store->buffer)); + TVM_FFI_ICHECK(write_region->region.size() == write_index.size() && + write_region->buffer.same_as(buf_store->buffer)); ffi::Array result; arith::Analyzer analyzer; for (int i = 0; i < static_cast(write_region->region.size()); i++) { @@ -192,7 +192,7 @@ Stmt InverseMapping::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); auto iter_map = arith::DetectIterMap(mapping_pattern, var_range, Bool(true), arith::Bijective, &analyzer); - CHECK_EQ(iter_map->indices.size(), loop_vars.size()); + TVM_FFI_ICHECK_EQ(iter_map->indices.size(), loop_vars.size()); ffi::Map inverse_mapping = arith::InverseAffineIterMap(iter_map->indices, loop_vars); // Step 3. Generate new body diff --git a/src/s_tir/transform/memhammer_intermediate_stage.cc b/src/s_tir/transform/memhammer_intermediate_stage.cc index d9116ac6553e..78f6170c56f7 100644 --- a/src/s_tir/transform/memhammer_intermediate_stage.cc +++ b/src/s_tir/transform/memhammer_intermediate_stage.cc @@ -163,7 +163,7 @@ class IndexPatternFinder : public ExprVisitor { } } if (extent > 1) { - ICHECK(max % extent == 0); + TVM_FFI_ICHECK(max % extent == 0); access_shape_.push_back(Integer(extent)); resulting_index_->push_back(floordiv(index, max / extent)); } @@ -286,11 +286,11 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, ffi::S if (target_buffer_load == nullptr) { target_buffer_load = buffer_load; } else { - CHECK(target_buffer_load->buffer.same_as(buffer_load->buffer)) + TVM_FFI_ICHECK(target_buffer_load->buffer.same_as(buffer_load->buffer)) << "More than one target buffer found"; - ICHECK(target_buffer_load->indices.size() == buffer_load->indices.size()); + TVM_FFI_ICHECK(target_buffer_load->indices.size() == buffer_load->indices.size()); for (size_t i = 0; i < target_buffer_load->indices.size(); i++) { - CHECK( + TVM_FFI_ICHECK( analyzer.CanProveEqual(target_buffer_load->indices[i], buffer_load->indices[i])); } } @@ -298,7 +298,7 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, ffi::S } return true; }); - CHECK(target_buffer_load); + TVM_FFI_ICHECK(target_buffer_load); } const BufferStoreNode* buf_store = TVM_TYPE_AS(body, BufferStoreNode); @@ -382,7 +382,7 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, ffi::S if (predicate.defined()) { // generated by coalescing - CHECK_EQ(loops_under_compute_location.size(), 2); + TVM_FFI_ICHECK_EQ(loops_under_compute_location.size(), 2); PrimExpr subst_value = 0; PrimExpr subst_predicate = Substitute(predicate.value(), subst_map); generate_body = IfThenElse(subst_predicate, generate_body); diff --git a/src/s_tir/transform/memhammer_lower_auto_copy.cc b/src/s_tir/transform/memhammer_lower_auto_copy.cc index 7e232e5ebf4f..988702eb96f1 100644 --- a/src/s_tir/transform/memhammer_lower_auto_copy.cc +++ b/src/s_tir/transform/memhammer_lower_auto_copy.cc @@ -670,8 +670,8 @@ class AutoCopyMutator : public StmtExprMutator { n->alloc_buffers = padder.PadSharedMemory(std::move(n->alloc_buffers)); return block; } - ICHECK_EQ(block->writes.size(), 1); - ICHECK_GE(block->reads.size(), 1); + TVM_FFI_ICHECK_EQ(block->writes.size(), 1); + TVM_FFI_ICHECK_GE(block->reads.size(), 1); BufferRegion target_read = block->reads[0]; if (block->reads.size() > 1) { @@ -682,7 +682,7 @@ class AutoCopyMutator : public StmtExprMutator { target_read = block->reads[i]; } } - ICHECK(found) << "Multiple buffer read"; + TVM_FFI_ICHECK(found) << "Multiple buffer read"; } int data_bits = target_read->buffer->dtype.bits(); diff --git a/src/s_tir/transform/memhammer_tensorcore_rewrite.cc b/src/s_tir/transform/memhammer_tensorcore_rewrite.cc index 2285b3843618..09776f8a0624 100644 --- a/src/s_tir/transform/memhammer_tensorcore_rewrite.cc +++ b/src/s_tir/transform/memhammer_tensorcore_rewrite.cc @@ -223,7 +223,7 @@ Stmt RewriteWmmaStore(Stmt stmt) { PostOrderVisit(buf_store->value, [&](const ObjectRef& obj) { const BufferLoadNode* load = obj.as(); if (load && load->buffer.scope() == "wmma.accumulator") { - ICHECK(buf_load == nullptr || buf_load->buffer.same_as(load->buffer)) + TVM_FFI_ICHECK(buf_load == nullptr || buf_load->buffer.same_as(load->buffer)) << "More than one source buffer of wmma accumulator found"; buf_load = load; } @@ -322,7 +322,7 @@ class WmmaToGlobalRewriter : public StmtExprMutator { private: Stmt VisitStmt_(const SeqStmtNode* op) final { if (op == tgt_stmt_) { - ICHECK_EQ(op->seq.size(), 2); + TVM_FFI_ICHECK_EQ(op->seq.size(), 2); Stmt wmma_to_shared = RewriteWmmaStore(op->seq[0]); Stmt shared_to_global = CoalescedAccess().Rewrite(op->seq[1], constraints_, nullptr); return SeqStmt({wmma_to_shared, shared_to_global}); @@ -435,7 +435,7 @@ Stmt RewriteMmaStore(Stmt stmt) { PostOrderVisit(buf_store->value, [&](const ObjectRef& obj) { const BufferLoadNode* load = obj.as(); if (load && load->buffer.scope() == "m16n8k8.matrixC") { - ICHECK(buf_load == nullptr || buf_load->buffer.same_as(load->buffer)) + TVM_FFI_ICHECK(buf_load == nullptr || buf_load->buffer.same_as(load->buffer)) << "More than one source buffer of mma accumulator found"; buf_load = load; } @@ -530,7 +530,7 @@ class MmaToGlobalRewriter : public StmtExprMutator { private: Stmt VisitStmt_(const SeqStmtNode* op) final { if (op == tgt_stmt_) { - ICHECK_EQ(op->seq.size(), 2); + TVM_FFI_ICHECK_EQ(op->seq.size(), 2); // Rewrite for local to shared.dyn // In this rewrite, we store local matrixC buffer to corresponding place in shared memory Stmt mma_to_shared = RewriteMmaStore(op->seq[0]); diff --git a/src/s_tir/transform/merge_shared_memory_allocations.cc b/src/s_tir/transform/merge_shared_memory_allocations.cc index 64e0a9662294..6463f4312507 100644 --- a/src/s_tir/transform/merge_shared_memory_allocations.cc +++ b/src/s_tir/transform/merge_shared_memory_allocations.cc @@ -127,7 +127,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { const VarNode* buf = op->buffer->data.get(); auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { - ICHECK_LT(it->second.level, scope_.size()); + TVM_FFI_ICHECK_LT(it->second.level, scope_.size()); if (IsAppropriateSharedMemory(ffi::GetRef(buf))) { scope_[it->second.level].touched.push_back(buf); } @@ -158,7 +158,8 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { const VarNode* buf = op->buffer->data.get(); auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { - ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store."; + TVM_FFI_ICHECK_LT(it->second.level, scope_.size()) + << "Load memory in places other than store."; if (IsAppropriateSharedMemory(ffi::GetRef(buf))) { scope_[it->second.level].touched.push_back(buf); } @@ -180,7 +181,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { // Directly reference to the variable count as a read. auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { - ICHECK_LT(it->second.level, scope_.size()); + TVM_FFI_ICHECK_LT(it->second.level, scope_.size()); if (IsAppropriateSharedMemory(ffi::GetRef(buf))) { scope_[it->second.level].touched.push_back(buf); } @@ -200,11 +201,11 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { e.touched = std::move(scope_.back().touched); scope_.pop_back(); int64_t end_index = static_cast(linear_seq_.size()); - ICHECK_GT(end_index, begin_index); + TVM_FFI_ICHECK_GT(end_index, begin_index); e.scope_pair_offset = begin_index - end_index; linear_seq_.push_back(e); // record the pointer to end index. - ICHECK_NE(end_index, 0U); + TVM_FFI_ICHECK_NE(end_index, 0U); linear_seq_[begin_index].scope_pair_offset = end_index - begin_index; } @@ -351,7 +352,7 @@ class SharedMemoryRewriter : public StmtExprMutator { template Node VisitBufferAccess(Node node) { if (IsAppropriateSharedMemory(node->buffer->data)) { - ICHECK_EQ(node->indices.size(), 1) + TVM_FFI_ICHECK_EQ(node->indices.size(), 1) << "MergeSharedMemoryAllocations expects flat memory buffers, " << "and is to be run after " << "FlattenBuffer"; @@ -374,7 +375,7 @@ class SharedMemoryRewriter : public StmtExprMutator { } if (IsAppropriateSharedMemory(buffer->data)) { - ICHECK_EQ(buffer->shape.size(), 1) + TVM_FFI_ICHECK_EQ(buffer->shape.size(), 1) << "Buffer " << buffer << " has shape " << buffer->shape << ". " << "MergeSharedMemoryAllocations expects flat memory buffers, " << "and is to be run after " @@ -389,7 +390,7 @@ class SharedMemoryRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_access_ptr())) { - ICHECK_EQ(op->args.size(), 5U); + TVM_FFI_ICHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); Var buffer = Downcast(op->args[1]); if (!IsAppropriateSharedMemory(buffer)) { @@ -402,7 +403,7 @@ class SharedMemoryRewriter : public StmtExprMutator { return Call(op->dtype, op->op, {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]}); } else if (op->op.same_as(builtin::ptx_cp_async())) { - ICHECK((op->args.size() == 5U) || (op->args.size() == 6U)); + TVM_FFI_ICHECK((op->args.size() == 5U) || (op->args.size() == 6U)); DataType dtype = op->dtype; Var buffer = Downcast(op->args[0]); if (!IsAppropriateSharedMemory(buffer)) { @@ -429,7 +430,7 @@ class SharedMemoryRewriter : public StmtExprMutator { PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) { auto it = buffer_byte_offsets_.find(buffer_var.get()); - ICHECK(it != buffer_byte_offsets_.end()); + TVM_FFI_ICHECK(it != buffer_byte_offsets_.end()); return indexdiv(it->second, dtype.bytes()); } @@ -519,7 +520,7 @@ class SharedMemoryRewriter : public StmtExprMutator { // In both cases, we need to handle the gen event correctly if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) { for (const VarNode* var : it->second.gen) { - ICHECK(shmem_allocs_.count(var)); + TVM_FFI_ICHECK(shmem_allocs_.count(var)); const AllocateNode* alloc = shmem_allocs_[var]; StorageEntry* dst_entry = FindAlloc(alloc); alloc_map_[var] = dst_entry; @@ -539,7 +540,7 @@ class SharedMemoryRewriter : public StmtExprMutator { * \return the new storage entry */ StorageEntry* NewAlloc(const AllocateNode* op, size_t const_nbits) { - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); // Re-use not successful, allocate a new buffer. StorageEntry* entry = arena_.make(); entry->allocs.push_back({op->buffer_var.get()}); @@ -552,7 +553,7 @@ class SharedMemoryRewriter : public StmtExprMutator { * \return the storage entry */ StorageEntry* FindAlloc(const AllocateNode* op) { - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); // skip plan for local variable, // compiler can do a better job with register allocation. const uint64_t match_range = 16; @@ -631,9 +632,9 @@ class SharedMemoryRewriter : public StmtExprMutator { */ void Free(const VarNode* var) { auto it = alloc_map_.find(var); - ICHECK(it != alloc_map_.end()); + TVM_FFI_ICHECK(it != alloc_map_.end()); StorageEntry* e = it->second; - ICHECK_NE(e->allocs.size(), 0U); + TVM_FFI_ICHECK_NE(e->allocs.size(), 0U); // disable reuse of small arrays if (e->const_nbits > 0 && e->const_nbits <= 32) return; diff --git a/src/s_tir/transform/plan_update_buffer_allocation_location.cc b/src/s_tir/transform/plan_update_buffer_allocation_location.cc index 528f43bade77..3b66230a2cad 100644 --- a/src/s_tir/transform/plan_update_buffer_allocation_location.cc +++ b/src/s_tir/transform/plan_update_buffer_allocation_location.cc @@ -162,7 +162,7 @@ class BufferAllocationLocator : public StmtExprMutator { } Stmt VisitStmt_(const SBlockNode* op) final { - ICHECK(!op->init.defined()); + TVM_FFI_ICHECK(!op->init.defined()); ffi::Array alloc_buffers; auto it = alloc_buffers_.find(op); if (it != alloc_buffers_.end()) { @@ -174,12 +174,12 @@ class BufferAllocationLocator : public StmtExprMutator { for (const MatchBufferRegion match_buffer : op->match_buffers) { const Var& target_var = match_buffer->buffer->data; const Var& source_var = match_buffer->source->buffer->data; - ICHECK(buffer_data_to_buffer_.count(source_var)); + TVM_FFI_ICHECK(buffer_data_to_buffer_.count(source_var)); buffer_data_to_buffer_.Set(target_var, match_buffer->buffer); } Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); // No longer consider buffers created by match_buffer inside the block when updating access // region. @@ -203,7 +203,7 @@ class BufferAllocationLocator : public StmtExprMutator { } Stmt InjectOpaqueBlock(Stmt body, const ffi::Array& alloc_buffers) { - ICHECK(!alloc_buffers.empty()); + TVM_FFI_ICHECK(!alloc_buffers.empty()); SBlock opaque_block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, diff --git a/src/s_tir/transform/remove_store_undef.cc b/src/s_tir/transform/remove_store_undef.cc index ebed134d2f3e..54c91fb7f016 100644 --- a/src/s_tir/transform/remove_store_undef.cc +++ b/src/s_tir/transform/remove_store_undef.cc @@ -51,7 +51,7 @@ class StoreUndefLocator : public StmtExprVisitor { StmtExprVisitor::VisitExpr(op->value); std::swap(has_undef_, stash_undef); if (stash_undef) { - ICHECK(SideEffect(op->value) <= CallEffectKind::kReadState) + TVM_FFI_ICHECK(SideEffect(op->value) <= CallEffectKind::kReadState) << "Error: T.undef() used in BufferStore expressions " << "must not have other side effects"; undef_stores_.insert(op); @@ -71,7 +71,7 @@ class StoreUndefLocator : public StmtExprVisitor { StmtExprVisitor::VisitExpr(op->value); std::swap(has_undef_, stash_undef); if (stash_undef) { - ICHECK(SideEffect(op->value) <= CallEffectKind::kReadState) + TVM_FFI_ICHECK(SideEffect(op->value) <= CallEffectKind::kReadState) << "Error: T.undef() used in Let expressions " << "must not have other side effects"; var_bindings_with_undef_.insert(op->var.get()); @@ -158,11 +158,12 @@ Pass RemoveStoreUndefInternal() { Pass ValidateAllUndefRemoved() { auto pass_func = [](PrimFunc f, IRModule m, tvm::transform::PassContext ctx) { bool contains_undef = ContainsUndefChecker::Check(f->body); - ICHECK(!contains_undef) << "Expected removal of BufferStore containing builtin::undef() " - << "to remove all instances of builtin::undef(). " - << "Instead, result was" - << "\n" - << f; + TVM_FFI_ICHECK(!contains_undef) + << "Expected removal of BufferStore containing builtin::undef() " + << "to remove all instances of builtin::undef(). " + << "Instead, result was" + << "\n" + << f; return f; }; return CreatePrimFuncPass(pass_func, 0, "s_tir.ValidateAllUndefRemoved", {}); diff --git a/src/s_tir/transform/remove_weight_layout_rewrite_block.cc b/src/s_tir/transform/remove_weight_layout_rewrite_block.cc index f9d5a5e6621c..609733f05731 100644 --- a/src/s_tir/transform/remove_weight_layout_rewrite_block.cc +++ b/src/s_tir/transform/remove_weight_layout_rewrite_block.cc @@ -72,16 +72,16 @@ class RemoveLayoutRewriteBlock : public StmtMutator { } // Step 0. Checking block attrs - ICHECK(block->alloc_buffers.empty()); - ICHECK(block->match_buffers.empty()); + TVM_FFI_ICHECK(block->alloc_buffers.empty()); + TVM_FFI_ICHECK(block->match_buffers.empty()); // Step 1. Checking the body is a BufferStore const auto* store = block->body.as(); - ICHECK(store); + TVM_FFI_ICHECK(store); // Step 2. Checking the rhs of buffer store is a BufferLoad const auto* load = store->value.as(); - ICHECK(load); + TVM_FFI_ICHECK(load); // Step 3. Update Buffer buf_map_.Set(load->buffer, store->buffer); @@ -95,7 +95,7 @@ class RemoveLayoutRewriteBlock : public StmtMutator { ffi::Array load_indices; for (auto ind : load->indices) { - ICHECK(ind->IsInstance()); + TVM_FFI_ICHECK(ind->IsInstance()); load_indices.push_back(Downcast(ind)); } buffer_var_to_index_map_[load->buffer->data.get()] = IndexMap(load_indices, store->indices); diff --git a/src/s_tir/transform/renew_defs.cc b/src/s_tir/transform/renew_defs.cc index 73158a376fb0..082538de7f87 100644 --- a/src/s_tir/transform/renew_defs.cc +++ b/src/s_tir/transform/renew_defs.cc @@ -37,7 +37,7 @@ using namespace tvm::tir; Var new_var = this->ReDefineVar(op->FIELD); \ Stmt stmt = StmtExprMutator::VisitStmt_(op); \ op = stmt.as(); \ - ICHECK(op != nullptr); \ + TVM_FFI_ICHECK(op != nullptr); \ auto n = ffi::make_object(*op); \ n->FIELD = std::move(new_var); \ return Stmt(n); \ @@ -145,7 +145,7 @@ class RenewDefMutator : public StmtExprMutator { Stmt VisitStmt_(const BufferStoreNode* op) final { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); Buffer buffer = VisitDeclOrRemapBuffer(op->buffer); if (buffer.same_as(op->buffer)) { return stmt; @@ -159,7 +159,7 @@ class RenewDefMutator : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); Buffer buffer = VisitDeclOrRemapBuffer(op->buffer); if (buffer.same_as(op->buffer)) { return expr; @@ -179,7 +179,7 @@ class RenewDefMutator : public StmtExprMutator { template void AddDefRemap(const T& source, const T& target) { - ICHECK(remap_.count(source) == 0); + TVM_FFI_ICHECK(remap_.count(source) == 0); remap_.Set(source, target); } @@ -188,7 +188,7 @@ class RenewDefMutator : public StmtExprMutator { if (it != remap_.end()) { return Downcast((*it).second); } - ICHECK(define); + TVM_FFI_ICHECK(define); auto redefine_if_is_var = [this](const PrimExpr& expr) -> PrimExpr { auto it = remap_.find(expr); diff --git a/src/s_tir/transform/storage_access.cc b/src/s_tir/transform/storage_access.cc index 873c54b1f066..f3fc337ddba7 100644 --- a/src/s_tir/transform/storage_access.cc +++ b/src/s_tir/transform/storage_access.cc @@ -37,7 +37,7 @@ void StorageAccessVisitor::VisitExpr_(const BufferLoadNode* op) { Var buf = op->buffer->data; StorageScope scope = GetScope(buf); if (Enabled(buf.get(), scope)) { - ICHECK(allow_append_) << op << " " << scope.to_string(); + TVM_FFI_ICHECK(allow_append_) << op << " " << scope.to_string(); AccessEntry e; e.threads = env_threads(); e.buffer = buf; @@ -55,7 +55,7 @@ void StorageAccessVisitor::VisitExpr_(const BufferLoadNode* op) { void StorageAccessVisitor::VisitStmt_(const BufferStoreNode* op) { allow_append_ = true; - ICHECK_EQ(curr_stmt_.access.size(), 0U); + TVM_FFI_ICHECK_EQ(curr_stmt_.access.size(), 0U); curr_stmt_.stmt = op; Var buf = op->buffer->data; @@ -83,7 +83,7 @@ void StorageAccessVisitor::VisitStmt_(const BufferStoreNode* op) { void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) { allow_append_ = true; - ICHECK_EQ(curr_stmt_.access.size(), 0U); + TVM_FFI_ICHECK_EQ(curr_stmt_.access.size(), 0U); curr_stmt_.stmt = op; StmtExprVisitor::VisitStmt_(op); // push to the scope @@ -96,7 +96,7 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) { void StorageAccessVisitor::VisitStmt_(const LetStmtNode* op) { allow_append_ = true; - ICHECK_EQ(curr_stmt_.access.size(), 0U); + TVM_FFI_ICHECK_EQ(curr_stmt_.access.size(), 0U); curr_stmt_.stmt = op; this->VisitExpr(op->value); // push to the scope @@ -110,7 +110,7 @@ void StorageAccessVisitor::VisitStmt_(const LetStmtNode* op) { void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tir::attr::double_buffer_write) { - ICHECK(double_buffer_write_ == nullptr); + TVM_FFI_ICHECK(double_buffer_write_ == nullptr); double_buffer_write_ = op->node.as(); scope_.push_back(std::vector()); StmtExprVisitor::VisitStmt_(op); @@ -170,7 +170,7 @@ void StorageAccessVisitor::VisitStmt_(const ForNode* op) { arith::IntSet::FromRange(Range::FromMinExtent(op->min, op->extent)); for (AccessEntry& e : s.access) { if (e.buffer.defined()) { - ICHECK(e.touched.size()); + TVM_FFI_ICHECK(e.touched.size()); ffi::Array new_touched; for (const auto& touched : e.touched) { new_touched.push_back(arith::EvalSet(touched, relax_map)); @@ -244,7 +244,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { const BufferLoadNode* load = op->args[0].as(); StmtExprVisitor::VisitExpr_(load); } else if (op->op.same_as(builtin::tvm_access_ptr())) { - ICHECK_EQ(op->args.size(), 5U); + TVM_FFI_ICHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); PrimExpr offset = op->args[2]; @@ -253,7 +253,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { StorageScope scope = GetScope(ffi::GetRef(buffer)); // The buffer scope. if (Enabled(buffer, scope)) { - ICHECK(allow_append_); + TVM_FFI_ICHECK(allow_append_); AccessEntry e; e.threads = env_threads(); e.dtype = dtype; @@ -271,7 +271,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { } StmtExprVisitor::VisitExpr_(op); } else if (op->op.same_as(builtin::tvm_storage_sync())) { - ICHECK(allow_append_); + TVM_FFI_ICHECK(allow_append_); const std::string& s = op->args[0].as()->value; if (s != "warp") { StorageScope scope = StorageScope::Create(s); diff --git a/src/s_tir/transform/tensorcore_infer_fragment.cc b/src/s_tir/transform/tensorcore_infer_fragment.cc index d1232e51645d..428f8f6f54fe 100644 --- a/src/s_tir/transform/tensorcore_infer_fragment.cc +++ b/src/s_tir/transform/tensorcore_infer_fragment.cc @@ -47,28 +47,28 @@ class FragmentGetter : public StmtExprVisitor { if (op->op.same_as(builtin::tvm_load_matrix_sync()) || op->op.same_as(builtin::tvm_store_matrix_sync())) { // Get shape and layout information from load and store intrinsic - ICHECK_EQ(op->args.size(), 8U); + TVM_FFI_ICHECK_EQ(op->args.size(), 8U); const VarNode* buffer_var = op->args[0].as(); - ICHECK(buffer_var); + TVM_FFI_ICHECK(buffer_var); // Get shape const IntImmNode* m = op->args[1].as(); const IntImmNode* n = op->args[2].as(); const IntImmNode* k = op->args[3].as(); const StringImmNode* layout = op->args[7].as(); - ICHECK(m); - ICHECK(n); - ICHECK(k); - ICHECK(layout); + TVM_FFI_ICHECK(m); + TVM_FFI_ICHECK(n); + TVM_FFI_ICHECK(k); + TVM_FFI_ICHECK(layout); std::string scope = GetPtrStorageScope(ffi::GetRef(buffer_var)); if (fragments.count(buffer_var)) { // check if the fragment has met before FragmentInfo info = fragments[buffer_var]; - ICHECK_EQ(m->value, info.m); - ICHECK_EQ(n->value, info.n); - ICHECK_EQ(k->value, info.k); + TVM_FFI_ICHECK_EQ(m->value, info.m); + TVM_FFI_ICHECK_EQ(n->value, info.n); + TVM_FFI_ICHECK_EQ(k->value, info.k); if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { - ICHECK_EQ(layout->value, info.layout); + TVM_FFI_ICHECK_EQ(layout->value, info.layout); } } else { // store metadata @@ -82,23 +82,23 @@ class FragmentGetter : public StmtExprVisitor { } } else if (op->op.same_as(builtin::tvm_fill_fragment())) { // Get shape information from fill intrinsic - ICHECK_EQ(op->args.size(), 6U); + TVM_FFI_ICHECK_EQ(op->args.size(), 6U); const VarNode* buffer_var = op->args[0].as(); - ICHECK(buffer_var); + TVM_FFI_ICHECK(buffer_var); // Get shape const IntImmNode* m = op->args[1].as(); const IntImmNode* n = op->args[2].as(); const IntImmNode* k = op->args[3].as(); - ICHECK(m); - ICHECK(n); - ICHECK(k); + TVM_FFI_ICHECK(m); + TVM_FFI_ICHECK(n); + TVM_FFI_ICHECK(k); std::string scope = GetPtrStorageScope(ffi::GetRef(buffer_var)); if (fragments.count(buffer_var)) { FragmentInfo info = fragments[buffer_var]; - ICHECK_EQ(m->value, info.m); - ICHECK_EQ(n->value, info.n); - ICHECK_EQ(k->value, info.k); + TVM_FFI_ICHECK_EQ(m->value, info.m); + TVM_FFI_ICHECK_EQ(n->value, info.n); + TVM_FFI_ICHECK_EQ(k->value, info.k); } else { // default to row major ordering FragmentInfo info(m->value, n->value, k->value, "row_major", scope); @@ -135,31 +135,31 @@ class FragmentChecker : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); // Check shape when calling tvm_mma_sync if (op->op.same_as(builtin::tvm_mma_sync()) || op->op.same_as(builtin::tvm_bmma_sync())) { - ICHECK_EQ(op->args.size(), 8U); + TVM_FFI_ICHECK_EQ(op->args.size(), 8U); const VarNode* buffer_var_d = op->args[0].as(); const VarNode* buffer_var_a = op->args[2].as(); const VarNode* buffer_var_b = op->args[4].as(); const VarNode* buffer_var_c = op->args[6].as(); - ICHECK(buffer_var_d); - ICHECK(buffer_var_a); - ICHECK(buffer_var_b); - ICHECK(buffer_var_c); + TVM_FFI_ICHECK(buffer_var_d); + TVM_FFI_ICHECK(buffer_var_a); + TVM_FFI_ICHECK(buffer_var_b); + TVM_FFI_ICHECK(buffer_var_c); // Check all fragment A, B, C and D have the same shape - ICHECK(CheckShape(buffer_var_d, buffer_var_a)); - ICHECK(CheckShape(buffer_var_d, buffer_var_b)); - ICHECK(CheckShape(buffer_var_d, buffer_var_c)); + TVM_FFI_ICHECK(CheckShape(buffer_var_d, buffer_var_a)); + TVM_FFI_ICHECK(CheckShape(buffer_var_d, buffer_var_b)); + TVM_FFI_ICHECK(CheckShape(buffer_var_d, buffer_var_c)); } } private: // A tool for checking shapes of two fragments bool CheckShape(const VarNode* buffer1, const VarNode* buffer2) { - CHECK(fragment_getter.fragments.count(buffer1)) + TVM_FFI_ICHECK(fragment_getter.fragments.count(buffer1)) << "Tensorecore fragment " << buffer1->name_hint << " must be filled (with tvm_fill_fragment) or loaded (with tvm_load_matrix_sync) before " "use."; - CHECK(fragment_getter.fragments.count(buffer2)) + TVM_FFI_ICHECK(fragment_getter.fragments.count(buffer2)) << "Tensorecore fragment " << buffer2->name_hint << " must be filled (with tvm_fill_fragment) or loaded (with tvm_load_matrix_sync) before " "use."; diff --git a/src/s_tir/transform/thread_storage_sync.cc b/src/s_tir/transform/thread_storage_sync.cc index 57d8f25b5125..6892f83a2925 100644 --- a/src/s_tir/transform/thread_storage_sync.cc +++ b/src/s_tir/transform/thread_storage_sync.cc @@ -115,7 +115,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor { } } if (sync_before_stmt) { - ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition"; + TVM_FFI_ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition"; syncs_inserted_.insert(s.stmt); } } @@ -142,7 +142,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor { } } if (sync_before_stmt) { - ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition"; + TVM_FFI_ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition"; syncs_inserted_.insert(s.stmt); break; } @@ -296,7 +296,7 @@ class ThreadSyncAfterWaitQueueInserter : public StmtExprMutator { auto sync = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync_scope_.to_string())})); auto inner = op->body.as(); - ICHECK(inner && inner->attr_key == tir::attr::async_wait_inflight_count); + TVM_FFI_ICHECK(inner && inner->attr_key == tir::attr::async_wait_inflight_count); auto zero = make_zero(DataType::Int(32)); auto new_body = SeqStmt({sync, inner->body}); return AttrStmt(zero, tir::attr::async_wait_queue_scope, op->value, @@ -370,7 +370,7 @@ class ThreadSyncInserter : public StmtExprMutator { if (op->op.same_as(builtin::tvm_access_ptr())) { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - ICHECK_EQ(op->args.size(), 5U); + TVM_FFI_ICHECK_EQ(op->args.size(), 5U); Var buffer_var(Downcast(op->args[1])); const IntImmNode* flag = op->args[4].as(); if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal && @@ -401,7 +401,7 @@ class ThreadSyncInserter : public StmtExprMutator { // private functions. Stmt InitGlobalBarrier(const AttrStmtNode* op) { - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); ffi::Array pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)}; Stmt prep = Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs)); Stmt body = op->body; @@ -418,9 +418,9 @@ class ThreadSyncInserter : public StmtExprMutator { return SeqStmt({prep, body}); } Stmt MakeGlobalBarrier() { - ICHECK(sync_scope_.rank == StorageRank::kGlobal); + TVM_FFI_ICHECK(sync_scope_.rank == StorageRank::kGlobal); if (!num_blocks_.defined()) { - ICHECK(!is_lead_.defined()); + TVM_FFI_ICHECK(!is_lead_.defined()); num_work_dim_ = thread_extents_.size(); for (const AttrStmtNode* attr : thread_extents_) { IterVar iv = Downcast(attr->node); @@ -433,7 +433,7 @@ class ThreadSyncInserter : public StmtExprMutator { } } } else { - ICHECK_EQ(num_work_dim_, thread_extents_.size()); + TVM_FFI_ICHECK_EQ(num_work_dim_, thread_extents_.size()); } return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_})); diff --git a/src/s_tir/transform/transform_mma_buffer_layout.cc b/src/s_tir/transform/transform_mma_buffer_layout.cc index e92f24738a47..1437a89f93dd 100644 --- a/src/s_tir/transform/transform_mma_buffer_layout.cc +++ b/src/s_tir/transform/transform_mma_buffer_layout.cc @@ -56,11 +56,11 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { // m16n8k8.matrixC // bi = 16, bj = 8 size_t size = buffer->shape.size(); - ICHECK_GE(size, 2); + TVM_FFI_ICHECK_GE(size, 2); const IntImmNode* dim0 = buffer->shape[size - 2].as(); const IntImmNode* dim1 = buffer->shape[size - 1].as(); - ICHECK(dim0 != nullptr && dim1 != nullptr); - ICHECK(dim0->value % 16 == 0 && dim1->value % 8 == 0); + TVM_FFI_ICHECK(dim0 != nullptr && dim1 != nullptr); + TVM_FFI_ICHECK(dim0->value % 16 == 0 && dim1->value % 8 == 0); std::vector new_shape; for (size_t i = 0; i < size - 2; ++i) { @@ -79,11 +79,11 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { // m16n8k8.matrixA // bi = 32, bj = 8 size_t size = buffer->shape.size(); - ICHECK_GE(size, 2); + TVM_FFI_ICHECK_GE(size, 2); const IntImmNode* dim0 = buffer->shape[size - 2].as(); const IntImmNode* dim1 = buffer->shape[size - 1].as(); - ICHECK(dim0 != nullptr && dim1 != nullptr); - ICHECK(dim0->value % 32 == 0 && dim1->value % 8 == 0); + TVM_FFI_ICHECK(dim0 != nullptr && dim1 != nullptr); + TVM_FFI_ICHECK(dim0->value % 32 == 0 && dim1->value % 8 == 0); std::vector new_shape; for (size_t i = 0; i < size - 2; ++i) { new_shape.push_back(buffer->shape[i]); @@ -101,11 +101,11 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { // m16n8k8.matrixB // bj = 8, bj = 32 size_t size = buffer->shape.size(); - ICHECK_GE(size, 2); + TVM_FFI_ICHECK_GE(size, 2); const IntImmNode* dim0 = buffer->shape[size - 2].as(); const IntImmNode* dim1 = buffer->shape[size - 1].as(); - ICHECK(dim0 != nullptr && dim1 != nullptr); - ICHECK(dim0->value % 8 == 0 && dim1->value % 32 == 0); + TVM_FFI_ICHECK(dim0 != nullptr && dim1 != nullptr); + TVM_FFI_ICHECK(dim0->value % 8 == 0 && dim1->value % 32 == 0); std::vector new_shape; for (size_t i = 0; i < size - 2; ++i) { new_shape.push_back(buffer->shape[i]); @@ -132,7 +132,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { auto* n = store.CopyOnWrite(); if (store->buffer.scope() == "m16n8k8.matrixC") { const auto index_map_func = tvm::ffi::Function::GetGlobal("tir.index_map_m16n8k8.matrixC"); - ICHECK(index_map_func.has_value()); + TVM_FFI_ICHECK(index_map_func.has_value()); auto index_map = IndexMap::FromFunc(2, *index_map_func); auto new_indices = index_map->MapIndices(store->indices, &analyzer); n->buffer = buffer_map_[store->buffer]; @@ -151,7 +151,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { auto* n = load.CopyOnWrite(); if (load->buffer.scope() == "m16n8k8.matrixC") { const auto index_map_func = tvm::ffi::Function::GetGlobal("tir.index_map_m16n8k8.matrixC"); - ICHECK(index_map_func.has_value()); + TVM_FFI_ICHECK(index_map_func.has_value()); auto index_map = IndexMap::FromFunc(2, *index_map_func); auto new_indices = index_map->MapIndices(load->indices, &analyzer); n->buffer = buffer_map_[load->buffer]; diff --git a/src/s_tir/transform/unify_thread_binding.cc b/src/s_tir/transform/unify_thread_binding.cc index c43e386e7143..f9b2d131cd95 100644 --- a/src/s_tir/transform/unify_thread_binding.cc +++ b/src/s_tir/transform/unify_thread_binding.cc @@ -112,9 +112,9 @@ class ThreadBindingUnifier : public StmtExprMutator { ffi::Map::iterator it = thread_tag2iter_var_map_.find(thread_tag); if (it != thread_tag2iter_var_map_.end()) { new_iter_var = (*it).second; - ICHECK(ana.CanProveEqual(dom->min, new_iter_var->dom->min)); - CHECK(ana.CanProveEqual(dom->extent, new_iter_var->dom->extent)) - << "ValueError: All loops that are bound to `" << thread_tag + TVM_FFI_ICHECK(ana.CanProveEqual(dom->min, new_iter_var->dom->min)); + TVM_FFI_CHECK(ana.CanProveEqual(dom->extent, new_iter_var->dom->extent), ValueError) + << "All loops that are bound to `" << thread_tag << "` should have the same extent. However, there are two loops with extent " << new_iter_var->dom->extent << " and " << dom->extent << ", which are not equal"; } else { @@ -134,7 +134,7 @@ class ThreadBindingUnifier : public StmtExprMutator { // binding of the kernel. Stmt new_stmt = StmtMutator::VisitStmt_(op); auto* new_node = new_stmt.as(); - ICHECK(new_node); + TVM_FFI_ICHECK(new_node); thread_block_depth_ = old_thread_block_depth; if (is_kernel_launch_scope) { return EmitLaunchThreads(new_node->body); diff --git a/src/s_tir/transform/using_assume_to_reduce_branches.cc b/src/s_tir/transform/using_assume_to_reduce_branches.cc index ab4ec76afef1..2c356c8f8efa 100644 --- a/src/s_tir/transform/using_assume_to_reduce_branches.cc +++ b/src/s_tir/transform/using_assume_to_reduce_branches.cc @@ -155,7 +155,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { } ~InternalConstraintContext() { - ICHECK_EQ(self->conditions_.size(), new_num_constraints) + TVM_FFI_ICHECK_EQ(self->conditions_.size(), new_num_constraints) << "Internal error: Each condition should only be popped once."; self->conditions_.erase(self->conditions_.begin() + old_num_constraints, self->conditions_.end()); @@ -296,18 +296,21 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { } else if (side_effect == tir::CallEffectKind::kReadState) { buffer_exprs.push_back(expr); } else { - LOG(FATAL) << "Assumption must be pure or read-only, but contained expression " << expr - << " with side-effect \'" << side_effect << "\'"; + TVM_FFI_THROW(InternalError) + << "Assumption must be pure or read-only, but contained expression " << expr + << " with side-effect \'" << side_effect << "\'"; } } additional_predicate = analyzer_->Simplify(std::move(additional_predicate)); - CHECK_EQ(buffer_exprs.size(), 1) << "T.assume must contain only a single buffer expression"; + TVM_FFI_ICHECK_EQ(buffer_exprs.size(), 1) + << "T.assume must contain only a single buffer expression"; auto* as_equal_node = buffer_exprs[0].as(); - CHECK(as_equal_node) << "T.assume buffer constraint must be of the form 'buffer[indices] == " - "value', but received " - << assumption; + TVM_FFI_ICHECK(as_equal_node) + << "T.assume buffer constraint must be of the form 'buffer[indices] == " + "value', but received " + << assumption; if (!as_equal_node) { // This assumption is an inequality on a data-dependent // conditional. Not an error for this to occur, but also not @@ -326,7 +329,8 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { load = opt.value(); value = as_equal_node->a; } else { - LOG(FATAL) << "T.assume buffer constraint must be of the form 'buffer[indices] == value'"; + TVM_FFI_THROW(InternalError) + << "T.assume buffer constraint must be of the form 'buffer[indices] == value'"; } // Populating the assume statement predicate, buffer, value @@ -342,8 +346,8 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { map_buffer_assumption[buf_data.buffer_load->buffer] = buf_data; auto has_side_effect = tir::SideEffect(value) > tir::CallEffectKind::kPure; - CHECK(!has_side_effect) << "Buffer value in constraint must be pure expression, but was " - << value; + TVM_FFI_ICHECK(!has_side_effect) + << "Buffer value in constraint must be pure expression, but was " << value; if (has_side_effect) { return; } diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc index 658e76be466c..6cd40b092249 100644 --- a/src/script/ir_builder/base.cc +++ b/src/script/ir_builder/base.cc @@ -44,7 +44,7 @@ void IRBuilderFrameNode::ExitWithScope() { void IRBuilderFrameNode::AddCallback(ffi::TypedFunction callback) { if (IRBuilder::Current()->frames.empty()) { - LOG(FATAL) << "ValueError: No frames in Builder to add callback"; + TVM_FFI_THROW(ValueError) << "No frames in Builder to add callback"; } IRBuilder::Current()->frames.back()->callbacks.push_back(callback); } @@ -63,9 +63,9 @@ std::vector* ThreadLocalBuilderStack() { void IRBuilder::EnterWithScope() { IRBuilderNode* n = this->get(); - CHECK(n->frames.empty()) << "ValueError: There are frame(s) left in the builder: " - << n->frames.size() - << ". Please use a fresh new builder every time building IRs"; + TVM_FFI_CHECK(n->frames.empty(), ValueError) + << "There are frame(s) left in the builder: " << n->frames.size() + << ". Please use a fresh new builder every time building IRs"; n->result = std::nullopt; std::vector* stack = ThreadLocalBuilderStack(); stack->push_back(*this); @@ -73,13 +73,13 @@ void IRBuilder::EnterWithScope() { void IRBuilder::ExitWithScope() { std::vector* stack = ThreadLocalBuilderStack(); - ICHECK(!stack->empty()); + TVM_FFI_ICHECK(!stack->empty()); stack->pop_back(); } IRBuilder IRBuilder::Current() { std::vector* stack = ThreadLocalBuilderStack(); - CHECK(!stack->empty()) << "ValueError: No builder in current scope"; + TVM_FFI_CHECK(!stack->empty(), ValueError) << "No builder in current scope"; return stack->back(); } @@ -97,9 +97,9 @@ Namer::FType& Namer::vtable() { void Namer::Name(ObjectRef node, ffi::String name) { static const FType& f = vtable(); - CHECK(node.defined()) << "ValueError: Cannot name nullptr with: " << name; - CHECK(f.can_dispatch(node)) << "ValueError: Do not know how to name type \"" - << node->GetTypeKey(); + TVM_FFI_CHECK(node.defined(), ValueError) << "Cannot name nullptr with: " << name; + TVM_FFI_CHECK(f.can_dispatch(node), ValueError) + << "Do not know how to name type \"" << node->GetTypeKey(); f(node, name); } diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index fae4ba41bfda..c7a73bec3b0c 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -29,17 +29,17 @@ TVM_FFI_STATIC_INIT_BLOCK() { IRModuleFrameNode::RegisterReflection(); } void IRModuleFrameNode::ExitWithScope() { ffi::Map func_map; - CHECK_EQ(functions.size(), global_var_map.size()) + TVM_FFI_ICHECK_EQ(functions.size(), global_var_map.size()) << "All functions must be defined in the IRModule. Got " << global_var_map.size() << "declared function(s), but only " << functions.size() << "defined function(s)."; for (const auto& kv : functions) { const GlobalVar& gv = kv.first; const BaseFunc& func = kv.second; - CHECK(func.defined()) << "ValueError: function " << gv->name_hint << " is not defined"; + TVM_FFI_CHECK(func.defined(), ValueError) << "function " << gv->name_hint << " is not defined"; func_map.Set(gv, func); } IRBuilder builder = IRBuilder::Current(); - ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + TVM_FFI_CHECK(!builder->result.defined(), ValueError) << "Builder.result has already been set"; auto dict_attrs = DictAttrs(attrs); builder->result = tvm::IRModule(func_map, {}, dict_attrs, global_infos); } diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index e609f1b8efd2..cee620d3ddec 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -45,14 +45,14 @@ inline relax::StructInfo GetGlobalVarStructInfo(const BaseFunc& func) { return tvm::relax::FuncStructInfo::OpaqueFunc( tvm::relax::StructInfoFromType(prim_func->ret_type)); } else { - LOG(FATAL) << "Unsupported function type: " << func->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Unsupported function type: " << func->GetTypeKey(); } } GlobalVar DeclFunction(const ffi::String& func_name, const BaseFunc& func_signature) { IRModuleFrame frame = FindModuleFrame(); - CHECK(!frame->global_var_map.count(func_name)) - << "ValueError: function " << func_name << " already exists"; + TVM_FFI_CHECK(!frame->global_var_map.count(func_name), ValueError) + << "function " << func_name << " already exists"; auto gvar_type = [&]() -> Type { if (auto prim_func = func_signature.as()) { @@ -66,8 +66,8 @@ GlobalVar DeclFunction(const ffi::String& func_name, const BaseFunc& func_signat GlobalVar gv = GlobalVar(func_name); gv->struct_info_ = GetGlobalVarStructInfo(func_signature); - CHECK(frame->functions.find(gv) == frame->functions.end()) - << "ValueError: function " << func_name << " has already been defined."; + TVM_FFI_CHECK(frame->functions.find(gv) == frame->functions.end(), ValueError) + << "function " << func_name << " has already been defined."; frame->global_var_map.Set(func_name, gv); frame->functions.Set(gv, func_signature); return gv; @@ -76,8 +76,8 @@ GlobalVar DeclFunction(const ffi::String& func_name, const BaseFunc& func_signat void DefFunction(const ffi::String& func_name, const BaseFunc& func) { IRModuleFrame frame = FindModuleFrame(); auto it = frame->global_var_map.find(func_name); - CHECK(it != frame->global_var_map.end()) - << "ValueError: function " << func_name << " does not exist, please declare it first."; + TVM_FFI_CHECK(it != frame->global_var_map.end(), ValueError) + << "function " << func_name << " does not exist, please declare it first."; const GlobalVar& gv = (*it).second; frame->functions.Set(gv, func); gv->struct_info_ = GetGlobalVarStructInfo(func); @@ -88,7 +88,7 @@ void ModuleAttrs(ffi::Map attrs, bool allow_overwrite) { // TODO(hongyi): add comments to explain why we need to check if the module frame is in scope IRModuleFrame frame = FindModuleFrame("I.ModuleAttr"); if (!allow_overwrite && !frame->attrs.empty()) { - LOG(FATAL) << "ValueError: Duplicate module attrs, previous one is:\n" << frame->attrs; + TVM_FFI_THROW(ValueError) << "Duplicate module attrs, previous one is:\n" << frame->attrs; } frame->attrs = attrs; } @@ -109,7 +109,7 @@ void ModuleSetAttr(const ffi::String& key, const ffi::Optional& value if (IRBuilder::IsInScope()) { IRModuleFrame frame = FindModuleFrame(); if (!allow_override && frame->attrs.find(key) != frame->attrs.end() && value.defined()) { - LOG(FATAL) << "ValueError: Duplicate module attr " << key; + TVM_FFI_THROW(ValueError) << "Duplicate module attr " << key; } if (value.defined()) { frame->attrs.Set(key, value.value()); @@ -117,7 +117,7 @@ void ModuleSetAttr(const ffi::String& key, const ffi::Optional& value frame->attrs.erase(key); } } else { - LOG(FATAL) << "ValueError: Currently in in the scope of a module."; + TVM_FFI_THROW(ValueError) << "Currently in in the scope of a module."; } } @@ -125,8 +125,8 @@ void ModuleGlobalInfos(ffi::Map> global_info if (IRBuilder::IsInScope()) { IRModuleFrame frame = FindModuleFrame("I.ModuleGlobalInfos"); if (!frame->global_infos.empty()) { - LOG(FATAL) << "ValueError: Duplicate module global_infos, previous one is:\n" - << frame->global_infos; + TVM_FFI_THROW(ValueError) << "Duplicate module global_infos, previous one is:\n" + << frame->global_infos; } frame->global_infos = global_infos; } @@ -136,12 +136,12 @@ VDevice LookupVDevice(ffi::String target_kind, int device_index) { if (IRBuilder::IsInScope()) { IRModuleFrame frame = FindModuleFrame(); if (frame->global_infos.empty()) { - LOG(FATAL) << "ValueError: The GlobalInfos in the IRModule is not defined."; + TVM_FFI_THROW(ValueError) << "The GlobalInfos in the IRModule is not defined."; } ffi::Array vdevices = frame->global_infos["vdevice"]; if (vdevices.empty() || device_index < 0 || static_cast(device_index) >= vdevices.size()) { - LOG(FATAL) << "ValueError: The target VDevice in the GlobalInfos was not found."; + TVM_FFI_THROW(ValueError) << "The target VDevice in the GlobalInfos was not found."; } if (target_kind == "vdevice") { return Downcast(vdevices[device_index]); diff --git a/src/script/ir_builder/ir/utils.h b/src/script/ir_builder/ir/utils.h index 54ea6ce6ad92..d53287013148 100644 --- a/src/script/ir_builder/ir/utils.h +++ b/src/script/ir_builder/ir/utils.h @@ -34,10 +34,10 @@ inline IRModuleFrame FindModuleFrame(const ffi::String& method) { return frame.value(); } } else { - LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure '" << method - << "' is called under I.ir_module()"; + TVM_FFI_THROW(ValueError) << "IRModule frame not find. Please ensure '" << method + << "' is called under I.ir_module()"; } - LOG(FATAL) << "ValueError: '" << method << "' must be called immediately under I.ir_module()"; + TVM_FFI_THROW(ValueError) << "'" << method << "' must be called immediately under I.ir_module()"; throw; } @@ -46,8 +46,8 @@ inline IRModuleFrame FindModuleFrame() { if (ffi::Optional frame = builder->FindFrame()) { return frame.value(); } else { - LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure it" - << " is called under I.ir_module()"; + TVM_FFI_THROW(ValueError) << "IRModule frame not find. Please ensure it" + << " is called under I.ir_module()"; } throw; } diff --git a/src/script/ir_builder/relax/distributed.cc b/src/script/ir_builder/relax/distributed.cc index 3efb38d44bf5..a505d317787f 100644 --- a/src/script/ir_builder/relax/distributed.cc +++ b/src/script/ir_builder/relax/distributed.cc @@ -33,9 +33,10 @@ Expr MakeCallTIRDist(Expr func, Tuple args, ffi::Optional packed_ints) { for (const distributed::DTensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->tensor_sinfo->shape.as(); - CHECK(shape != nullptr) << "out_sinfo of call_tir should have defined ShapeExpr as shape. " - "However, one given structure info is " - << sinfo; + TVM_FFI_ICHECK(shape != nullptr) + << "out_sinfo of call_tir should have defined ShapeExpr as shape. " + "However, one given structure info is " + << sinfo; } StructInfo out_sinfo{nullptr}; diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index c57ca041b328..4b5274cf248a 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -46,8 +46,8 @@ void SeqExprFrameNode::ExitWithScope() { if (ffi::Optional block_frame = IRBuilder::Current()->GetLastFrame()) { block_frame.value()->ExitWithScope(); - ICHECK(!IRBuilder::Current()->GetLastFrame().defined()) - << "ValueError: There is some remaining BindingBlockFrame that is not properly popped out."; + TVM_FFI_CHECK(!IRBuilder::Current()->GetLastFrame().defined(), ValueError) + << "There is some remaining BindingBlockFrame that is not properly popped out."; } RelaxFrameNode::ExitWithScope(); } @@ -68,8 +68,9 @@ void FunctionFrameNode::ExitWithScope() { IRBuilder builder = IRBuilder::Current(); SeqExprFrameNode::ExitWithScope(); // Step 1: Create the function. - CHECK(output.defined()) << "ValueError: A Relax function must have a return value. Please use " - "`return` to return an Expr"; + TVM_FFI_CHECK(output.defined(), ValueError) + << "A Relax function must have a return value. Please use " + "`return` to return an Expr"; Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); // if the function is not private, add a global symbol to its attributes @@ -86,12 +87,13 @@ void FunctionFrameNode::ExitWithScope() { // Step 2: Update IRModule. if (builder->frames.empty()) { // Case 0. No outer frame, return function directly - ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + TVM_FFI_CHECK(!builder->result.defined(), ValueError) << "Builder.result has already been set"; builder->result = func; } else if (ffi::Optional opt_frame = builder->FindFrame()) { // Case 1. A global function of an IRModule - CHECK(name.has_value()) << "ValueError: The function name must be defined before exiting the " - "function scope, if it's defined in a Module"; + TVM_FFI_CHECK(name.has_value(), ValueError) + << "The function name must be defined before exiting the " + "function scope, if it's defined in a Module"; const IRModuleFrame& frame = opt_frame.value(); const ffi::String& func_name = name.value_or(""); if (!frame->global_var_map.count(func_name)) { @@ -102,7 +104,7 @@ void FunctionFrameNode::ExitWithScope() { // Note we do checks to disallow redefinition of functions inside the `DefFunction`. ir::DefFunction(func_name, func); } else { - LOG(FATAL) << "ValueError: Cannot find where to insert Relax.Function"; + TVM_FFI_THROW(ValueError) << "Cannot find where to insert Relax.Function"; } } @@ -114,13 +116,13 @@ void BindingBlockFrameNode::EnterWithScope() { if (block_frame.defined()) { block_frame.value()->ExitWithScope(); // Block frames cannot appear consecutively. - ICHECK(!IRBuilder::Current()->GetLastFrame()); + TVM_FFI_ICHECK(!IRBuilder::Current()->GetLastFrame()); } // Step 2. Deal with the new block frame. RelaxFrameNode::EnterWithScope(); ffi::Optional func_frame = IRBuilder::Current()->FindFrame(); - CHECK(func_frame.defined()) - << "ValueError: Cannot find FunctionFrame when creating BindingBlocks, Please ensure " + TVM_FFI_CHECK(func_frame.defined(), ValueError) + << "Cannot find FunctionFrame when creating BindingBlocks, Please ensure " "creating the block under Relax function scope."; const tvm::relax::BlockBuilder& block_builder = func_frame.value()->block_builder; if (is_dataflow) { @@ -188,21 +190,22 @@ void BindingBlockFrameNode::ExitWithScope() { // Step 3. Get the last frame from the IRBuilder frame stack. ffi::Optional opt_last_frame = IRBuilder::Current()->GetLastFrame(); - ICHECK(opt_last_frame.defined()); + TVM_FFI_ICHECK(opt_last_frame.defined()); RelaxFrame last_frame = opt_last_frame.value(); // Step 4. Since we popped out any possible block frame when entering the "with" scope of the // current frame, the last frame cannot be a block frame. - ICHECK(!last_frame->IsInstance()); + TVM_FFI_ICHECK(!last_frame->IsInstance()); // Step 5. Push the block frame into the corresponding field of the last frame. if (const auto* seq_frame = last_frame.as()) { auto frame = ffi::GetRef(seq_frame); frame->binding_blocks.push_back(block); } else { - LOG(FATAL) << "ValueError: Currently the last frame is supposed to be either a function frame " - "or a block frame. However, the last frame is \"" - << last_frame->GetTypeKey() << "\"."; + TVM_FFI_THROW(ValueError) + << "Currently the last frame is supposed to be either a function frame " + "or a block frame. However, the last frame is \"" + << last_frame->GetTypeKey() << "\"."; } // Step 6. Start another binding block when a dataflow block ended. @@ -216,7 +219,7 @@ void IfFrameNode::EnterWithScope() { for (const IRBuilderFrame& frame : frames) { const auto* block_frame = frame.as(); if (block_frame && block_frame->is_dataflow) { - LOG(FATAL) << "ValueError: Cannot create an IfFrame inside a dataflow block."; + TVM_FFI_THROW(ValueError) << "Cannot create an IfFrame inside a dataflow block."; } } RelaxFrameNode::EnterWithScope(); @@ -224,10 +227,10 @@ void IfFrameNode::EnterWithScope() { void IfFrameNode::ExitWithScope() { RelaxFrameNode::ExitWithScope(); - CHECK(then_expr.defined()) - << "ValueError: The body of then part is expected to be defined before exiting."; - CHECK(then_expr.defined()) - << "ValueError: The body of else part is expected to be defined before exiting."; + TVM_FFI_CHECK(then_expr.defined(), ValueError) + << "The body of then part is expected to be defined before exiting."; + TVM_FFI_CHECK(then_expr.defined(), ValueError) + << "The body of else part is expected to be defined before exiting."; auto body = tvm::relax::If(condition, then_expr.value(), else_expr.value()); var = Emit(body); IRBuilder::Name(var_name, var); @@ -235,9 +238,8 @@ void IfFrameNode::ExitWithScope() { void ThenFrameNode::EnterWithScope() { IfFrame frame = FindIfFrame("R.Then"); - CHECK(!frame->then_expr.defined()) - << "ValueError: Duplicate then branch declaration, previous one is " - << frame->then_expr.value(); + TVM_FFI_CHECK(!frame->then_expr.defined(), ValueError) + << "Duplicate then branch declaration, previous one is " << frame->then_expr.value(); SeqExprFrameNode::EnterWithScope(); } @@ -252,10 +254,9 @@ void ThenFrameNode::ExitWithScope() { void ElseFrameNode::EnterWithScope() { IfFrame frame = FindIfFrame("R.Else"); - CHECK(frame->then_expr.defined()) << "The else branch should follow then branch"; - CHECK(!frame->else_expr.defined()) - << "ValueError: Duplicate else branch declaration, previous one is " - << frame->else_expr.value(); + TVM_FFI_ICHECK(frame->then_expr.defined()) << "The else branch should follow then branch"; + TVM_FFI_CHECK(!frame->else_expr.defined(), ValueError) + << "Duplicate else branch declaration, previous one is " << frame->else_expr.value(); SeqExprFrameNode::EnterWithScope(); } @@ -265,7 +266,7 @@ void ElseFrameNode::ExitWithScope() { output = GetSeqExprForBranch(ffi::GetRef(this), &var_name); IfFrame frame = FindIfFrame("R.Else"); frame->else_expr = output; - CHECK(frame->var_name == var_name) + TVM_FFI_ICHECK(frame->var_name == var_name) << "This last binding of both branches must provide the same variable. " << "However, the R.Then branch provides variable " << frame->var_name << ", while the R.Else branch provides variable " << var_name; diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 55f473a7ba0a..35b44feae399 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -80,8 +80,8 @@ tvm::relax::Var Arg(const ffi::String& name, const tvm::relax::StructInfo& struc void FuncName(const ffi::String& name) { FunctionFrame frame = FindFunctionFrame("R.func_name"); if (frame->name.has_value()) { - LOG(FATAL) << "ValueError: Duplicate function name, previous one is: \"" << frame->name.value() - << "\""; + TVM_FFI_THROW(ValueError) << "Duplicate function name, previous one is: \"" + << frame->name.value() << "\""; } frame->name = name; } @@ -90,15 +90,15 @@ void FuncAttrs(ffi::Map attrs) { FunctionFrame frame = FindFunctionFrame("R.func_attr"); for (const auto& [key, value] : attrs) { if (key == tvm::attr::kGlobalSymbol && frame->is_private.value_or(Bool(false))->value) { - LOG(FATAL) << "ValueError: " - << "A private function may not have the kGlobalSymbol (\"" - << tvm::attr::kGlobalSymbol << "\") attribute. " - << "However, a private function specified the global symbol as " << value; + TVM_FFI_THROW(ValueError) << "A private function may not have the kGlobalSymbol (\"" + << tvm::attr::kGlobalSymbol << "\") attribute. " + << "However, a private function specified the global symbol as " + << value; } if (auto prev = frame->attrs.Get(key)) { - LOG(FATAL) << "ValueError: " - << "Duplicate R.func_attr annotation for key = \"" << key << "\". " - << "Previous value was " << prev.value() << ", with later definition as " << value; + TVM_FFI_THROW(ValueError) << "Duplicate R.func_attr annotation for key = \"" << key << "\". " + << "Previous value was " << prev.value() + << ", with later definition as " << value; } else { frame->attrs.Set(key, value); } @@ -108,8 +108,8 @@ void FuncAttrs(ffi::Map attrs) { void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo) { FunctionFrame frame = FindFunctionFrame("R.func_ret_struct_info"); if (frame->ret_struct_info.defined()) { - LOG(FATAL) << "ValueError: Duplicate function return struct info, previous one is:\n " - << frame->ret_struct_info.value(); + TVM_FFI_THROW(ValueError) << "Duplicate function return struct info, previous one is:\n " + << frame->ret_struct_info.value(); } frame->ret_struct_info = ret_sinfo; } @@ -135,8 +135,7 @@ void FuncRetValue(const tvm::relax::Expr& value) { } // Step 2. Add the output value to the function frame. FunctionFrame frame = FindFunctionFrame("return"); - CHECK(!frame->output.defined()) - << "ValueError: " + TVM_FFI_CHECK(!frame->output.defined(), ValueError) << "Relax functions do not support multiple return statement. " << "However, return of " << normalized_value << " occurred after a return of " << frame->output << ". " @@ -177,11 +176,11 @@ void DataflowBlockOutput(const ffi::Array& vars) { // Step 1. Check that we're in a Dataflow block that is not ended. ffi::Optional block_frame = IRBuilder::Current()->GetLastFrame(); - CHECK(block_frame.defined() && block_frame.value()->is_dataflow) - << "ValueError: `R.output` should appear inside a dataflow block. However, the current " + TVM_FFI_CHECK(block_frame.defined() && block_frame.value()->is_dataflow, ValueError) + << "`R.output` should appear inside a dataflow block. However, the current " "innermost block is not a dataflow block."; - CHECK(!block_frame.value()->block_ended) - << "ValueError: It is not allowed for a dataflow block to have multiple output operation."; + TVM_FFI_CHECK(!block_frame.value()->block_ended, ValueError) + << "It is not allowed for a dataflow block to have multiple output operation."; // Step 2. Mark the block frame ended of construction, so that any followup binding after this // mark in the dataflow block will lead to an error. @@ -191,8 +190,9 @@ void DataflowBlockOutput(const ffi::Array& vars) { // block. const ffi::Array& emitted_vars = block_frame.value()->emitted_vars; for (const tvm::relax::Var& var : vars) { - CHECK(std::find(emitted_vars.begin(), emitted_vars.end(), var) != emitted_vars.end()) - << "ValueError: An output variable is not emitted by this dataflow block. Please make sure " + TVM_FFI_CHECK(std::find(emitted_vars.begin(), emitted_vars.end(), var) != emitted_vars.end(), + ValueError) + << "An output variable is not emitted by this dataflow block. Please make sure " "all dataflow block output variables are emitted exactly by this block."; block_frame.value()->output_vars.push_back(var); } @@ -218,7 +218,8 @@ tvm::relax::Var Emit(const tvm::relax::Expr& expr, if (!expr->struct_info_.defined()) { UpdateStructInfo(expr, sinfo); } else { - CHECK(StructInfoBaseCheck(sinfo, GetStructInfo(expr)) != tvm::relax::BaseCheckResult::kFailL0) + TVM_FFI_ICHECK(StructInfoBaseCheck(sinfo, GetStructInfo(expr)) != + tvm::relax::BaseCheckResult::kFailL0) << "Invalid annotation. Got rhs value struct info: " << GetStructInfo(expr) << ", given struct info: " << sinfo; } diff --git a/src/script/ir_builder/relax/utils.h b/src/script/ir_builder/relax/utils.h index 30ca9753d497..a2204ef54ee4 100644 --- a/src/script/ir_builder/relax/utils.h +++ b/src/script/ir_builder/relax/utils.h @@ -35,8 +35,8 @@ inline FunctionFrame FindFunctionFrame(const ffi::String& method) { if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) { return frame.value(); } - LOG(FATAL) << "ValueError: Function frame not find. Please ensure '" << method - << "' is called under R.function()"; + TVM_FFI_THROW(ValueError) << "Function frame not find. Please ensure '" << method + << "' is called under R.function()"; throw; } @@ -44,16 +44,16 @@ inline IfFrame FindIfFrame(const ffi::String& method) { if (ffi::Optional frame = IRBuilder::Current()->GetLastFrame()) { return frame.value(); } else { - LOG(FATAL) << "ValueError: IfThenElse frame not find. Please ensure '" << method - << "' is called under R.if_()"; + TVM_FFI_THROW(ValueError) << "IfThenElse frame not find. Please ensure '" << method + << "' is called under R.if_()"; } throw; } inline tvm::relax::BlockBuilder GetBlockBuilder() { ffi::Optional frame = IRBuilder::Current()->FindFrame(); - CHECK(frame.defined()) << "ValueError: Relax Function frame not find. Please ensure " - "assignment is called under R.function()"; + TVM_FFI_CHECK(frame.defined(), ValueError) << "Relax Function frame not find. Please ensure " + "assignment is called under R.function()"; return frame.value()->block_builder; } @@ -63,9 +63,9 @@ inline BindingBlockFrame CheckBindingBlockFrameExistAndUnended() { ffi::Optional block_frame = IRBuilder::Current()->GetLastFrame(); - CHECK(block_frame.defined()) << "ValueError: Block frame not find"; - CHECK(!block_frame.value()->block_ended) - << "ValueError: New binding is not allowed after dataflow block output."; + TVM_FFI_CHECK(block_frame.defined(), ValueError) << "Block frame not find"; + TVM_FFI_CHECK(!block_frame.value()->block_ended, ValueError) + << "New binding is not allowed after dataflow block output."; return block_frame.value(); } @@ -80,14 +80,14 @@ inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, ffi::S method = "R.Else"; output_var_suffix = "_else"; } else { - ICHECK(false) << "TypeError: Unsupported frame type: " << frame->GetTypeKey(); + TVM_FFI_CHECK(false, TypeError) << "Unsupported frame type: " << frame->GetTypeKey(); } // Step 1. Check non-empty block and last binding is non-dataflow - CHECK(!frame->binding_blocks.empty()) + TVM_FFI_ICHECK(!frame->binding_blocks.empty()) << "Empty body is not allowed for '" << method << "' statements."; const tvm::relax::BindingBlock& last_block = frame->binding_blocks.back(); - CHECK(!last_block->bindings.empty()) << "Blocks are expected to be non-empty."; + TVM_FFI_ICHECK(!last_block->bindings.empty()) << "Blocks are expected to be non-empty."; // Step 2. Update the last binding of each branch. While we could // use the last bound value of each branch as a SeqExpr body, the @@ -96,7 +96,7 @@ inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, ffi::S // variable name. tvm::relax::Binding last_binding = last_block->bindings.back(); - CHECK(!last_binding->var->IsInstance()) + TVM_FFI_ICHECK(!last_binding->var->IsInstance()) << "A non-dataflow var is expected in the last binding of '" << method << "'."; *var_name = last_binding->var->name_hint(); @@ -123,7 +123,7 @@ inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, ffi::S tvm::relax::MatchCast(new_var, match_cast->value, match_cast->struct_info)); body = new_var; } else { - ICHECK(false) << "TypeError: Unsupported binding type: " << last_binding->GetTypeKey(); + TVM_FFI_CHECK(false, TypeError) << "Unsupported binding type: " << last_binding->GetTypeKey(); } new_blocks.push_back(last_block->IsInstance() diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index abc981de1f45..e5008c74ed70 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -63,11 +63,12 @@ void PrimFuncFrameNode::ExitWithScope() { func = tvm::tir::ScriptComplete(func, root_alloc_buffers); IRBuilder builder = IRBuilder::Current(); if (builder->frames.empty()) { - ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + TVM_FFI_CHECK(!builder->result.defined(), ValueError) << "Builder.result has already been set"; builder->result = func; } else if (ffi::Optional opt_frame = builder->FindFrame()) { - CHECK(name.has_value()) << "ValueError: The function name must be defined before exiting the " - "function scope, if it's defined in a Module"; + TVM_FFI_CHECK(name.has_value(), ValueError) + << "The function name must be defined before exiting the " + "function scope, if it's defined in a Module"; const ir::IRModuleFrame& frame = opt_frame.value(); const ffi::String& func_name = name.value_or(""); if (!frame->global_var_map.count(func_name)) { @@ -78,7 +79,7 @@ void PrimFuncFrameNode::ExitWithScope() { // Note we do checks to disallow redefinition of functions inside the `DefFunction`. ir::DefFunction(func_name, func); } else { - LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc"; + TVM_FFI_THROW(ValueError) << "Cannot find where to insert PrimFunc"; } } @@ -96,9 +97,10 @@ void SBlockFrameNode::ExitWithScope() { writes.value_or(ffi::Array()), name, AsStmt(stmts), init, tir_alloc_buffers, match_buffers, attrs); if (no_realize) { - CHECK(iter_values.empty()) - << "ValueError: Block bindings are not allowed when `no_realize=True`"; - CHECK(!predicate.defined()) << "ValueError: `T.where` is not allowed when `no_realize=True`"; + TVM_FFI_CHECK(iter_values.empty(), ValueError) + << "Block bindings are not allowed when `no_realize=True`"; + TVM_FFI_CHECK(!predicate.defined(), ValueError) + << "`T.where` is not allowed when `no_realize=True`"; AddToParent(block); } else { AddToParent(tvm::tir::SBlockRealize(iter_values, predicate.value_or(Bool(true)), block)); @@ -108,7 +110,7 @@ void SBlockFrameNode::ExitWithScope() { void BlockInitFrameNode::EnterWithScope() { SBlockFrame frame = FindSBlockFrame("T.init"); if (frame->init.defined()) { - LOG(FATAL) << "ValueError: Duplicate block init declaration"; + TVM_FFI_THROW(ValueError) << "Duplicate block init declaration"; } TIRFrameNode::EnterWithScope(); } @@ -158,10 +160,11 @@ void WhileFrameNode::ExitWithScope() { void IfFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); if (!stmts.empty()) { - LOG(FATAL) << "stmt within IfThenElse frame should be either in ThenFrame or ElseFrame"; + TVM_FFI_THROW(InternalError) + << "stmt within IfThenElse frame should be either in ThenFrame or ElseFrame"; } if (!then_stmts.defined()) { - LOG(FATAL) << "IfThenElse frame should have at least one then branch"; + TVM_FFI_THROW(InternalError) << "IfThenElse frame should have at least one then branch"; } AddToParent(tvm::tir::IfThenElse( condition, AsStmt(then_stmts.value()), @@ -171,8 +174,8 @@ void IfFrameNode::ExitWithScope() { void ThenFrameNode::EnterWithScope() { IfFrame frame = FindIfFrame("T.then_"); if (frame->then_stmts.defined()) { - LOG(FATAL) << "ValueError: Duplicate then branch declaration, previous one is " - << frame->then_stmts.value(); + TVM_FFI_THROW(ValueError) << "Duplicate then branch declaration, previous one is " + << frame->then_stmts.value(); } TIRFrameNode::EnterWithScope(); } @@ -185,11 +188,11 @@ void ThenFrameNode::ExitWithScope() { void ElseFrameNode::EnterWithScope() { IfFrame frame = FindIfFrame("T.else_"); if (!frame->then_stmts.defined()) { - LOG(FATAL) << "The else branch should follow then branch"; + TVM_FFI_THROW(InternalError) << "The else branch should follow then branch"; } if (frame->else_stmts.defined()) { - LOG(FATAL) << "ValueError: Duplicate else branch declaration, previous one is " - << frame->else_stmts.value(); + TVM_FFI_THROW(ValueError) << "Duplicate else branch declaration, previous one is " + << frame->else_stmts.value(); } TIRFrameNode::EnterWithScope(); } diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index cd8838f70102..a3439813559f 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -35,8 +35,9 @@ Buffer BufferDecl(ffi::Array shape, DataType dtype, ffi::String buffer ffi::Optional elem_offset, ffi::String storage_scope, int align, int offset_factor, ffi::String buffer_type, ffi::Optional> axis_separators) { - CHECK(buffer_type == "auto" || buffer_type == "default" || buffer_type.empty()) - << "ValueError: `buffer_type` must be `auto` or `default` or empty"; + TVM_FFI_CHECK(buffer_type == "auto" || buffer_type == "default" || buffer_type.empty(), + ValueError) + << "`buffer_type` must be `auto` or `default` or empty"; Var buffer_data; if (!data.defined()) { DataType storage_dtype = dtype; @@ -89,7 +90,8 @@ Buffer Arg(ffi::String name, Buffer buffer) { void FuncName(ffi::String name) { PrimFuncFrame frame = FindPrimFuncFrame("T.func_name"); if (frame->name.has_value()) { - LOG(FATAL) << "ValueError: Duplicate prim func name, previous one is " << frame->name.value(); + TVM_FFI_THROW(ValueError) << "Duplicate prim func name, previous one is " + << frame->name.value(); } frame->name = name; } @@ -99,16 +101,16 @@ void FuncAttrs(ffi::Map new_attrs) { PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr"); for (const auto& [key, value] : new_attrs) { if (key == tvm::attr::kGlobalSymbol && frame->is_private) { - LOG(FATAL) << "ValueError: " - << "A private function may not have the kGlobalSymbol (\"" - << tvm::attr::kGlobalSymbol << "\") attribute. " - << "However, a private function specified the global symbol as " << value; + TVM_FFI_THROW(ValueError) << "A private function may not have the kGlobalSymbol (\"" + << tvm::attr::kGlobalSymbol << "\") attribute. " + << "However, a private function specified the global symbol as " + << value; } if (auto prev = frame->attrs.Get(key)) { - LOG(FATAL) << "ValueError: " - << "Duplicate prim func annotation for key = \"" << key << "\". " - << "Previous value was " << prev.value() << ", with later definition as " << value; + TVM_FFI_THROW(ValueError) << "Duplicate prim func annotation for key = \"" << key << "\". " + << "Previous value was " << prev.value() + << ", with later definition as " << value; } else { frame->attrs.Set(key, value); } @@ -118,8 +120,8 @@ void FuncAttrs(ffi::Map new_attrs) { tvm::Type FuncRet(tvm::Type ret_type) { PrimFuncFrame frame = FindPrimFuncFrame("T.ret_type"); if (frame->ret_type.defined()) { - LOG(FATAL) << "ValueError: Duplicate prim func return type, previous one is " - << frame->ret_type.value(); + TVM_FFI_THROW(ValueError) << "Duplicate prim func return type, previous one is " + << frame->ret_type.value(); } frame->ret_type = ret_type; return ret_type; @@ -140,7 +142,7 @@ Buffer MatchBuffer(ObjectRef param, ffi::Array shape, DataType dtype, return buffer; } } - LOG(FATAL) << "ValueError: Can not bind non-input param to buffer."; + TVM_FFI_THROW(ValueError) << "Can not bind non-input param to buffer."; } else if (const auto* buffer_load = param.as()) { SBlockFrame frame = FindSBlockFrame("T.match_buffer"); frame->match_buffers.push_back(tvm::tir::MatchBufferRegion( @@ -150,7 +152,7 @@ Buffer MatchBuffer(ObjectRef param, ffi::Array shape, DataType dtype, frame->match_buffers.push_back( tvm::tir::MatchBufferRegion(buffer, ffi::GetRef(buffer_region))); } else { - LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer."; + TVM_FFI_THROW(ValueError) << "Unexpected type for TIR MatchBuffer."; } return buffer; } @@ -176,8 +178,8 @@ BlockInitFrame Init() { return BlockInitFrame(ffi::make_objectpredicate.defined()) { - LOG(FATAL) << "ValueError: Duplicate block predicate declaration, previous one is " - << frame->predicate; + TVM_FFI_THROW(ValueError) << "Duplicate block predicate declaration, previous one is " + << frame->predicate; } frame->predicate = predicate; } @@ -186,7 +188,8 @@ void Reads(ffi::Array buffer_slices) { using namespace tvm::tir; SBlockFrame frame = FindSBlockFrame("T.reads"); if (frame->reads.defined()) { - LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is " << frame->reads; + TVM_FFI_THROW(ValueError) << "Duplicate read region declaration, previous one is " + << frame->reads; } ffi::Array reads; for (const ObjectRef& obj : buffer_slices) { @@ -195,7 +198,7 @@ void Reads(ffi::Array buffer_slices) { } else if (auto buffer_load = obj.as()) { reads.push_back(BufferRegionFromLoad(buffer_load.value())); } else { - LOG(FATAL) << "Invalid type for buffer reads."; + TVM_FFI_THROW(InternalError) << "Invalid type for buffer reads."; } } frame->reads = reads; @@ -205,8 +208,8 @@ void Writes(ffi::Array buffer_slices) { using namespace tvm::tir; SBlockFrame frame = FindSBlockFrame("T.writes"); if (frame->writes.defined()) { - LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is " - << frame->writes; + TVM_FFI_THROW(ValueError) << "Duplicate write region declaration, previous one is " + << frame->writes; } ffi::Array writes; for (const ObjectRef& obj : buffer_slices) { @@ -215,7 +218,7 @@ void Writes(ffi::Array buffer_slices) { } else if (auto buffer_load = obj.as()) { writes.push_back(BufferRegionFromLoad(buffer_load.value())); } else { - LOG(FATAL) << "Invalid type for buffer writes."; + TVM_FFI_THROW(InternalError) << "Invalid type for buffer writes."; } } frame->writes = writes; @@ -245,8 +248,9 @@ ffi::Map MergeAnnotations(const ffi::Map& ne } // Case 2.2: the values are not both dicts, check if the keys are the same if (!ffi::AnyEqual()(old_value.value(), value)) { - LOG(FATAL) << "ValueError: Try to merge two annotations with different values for key `" - << key << "`, previous one is " << old_value.value() << ", new one is " << value; + TVM_FFI_THROW(ValueError) << "Try to merge two annotations with different values for key `" + << key << "`, previous one is " << old_value.value() + << ", new one is " << value; } } return result; @@ -275,8 +279,8 @@ Buffer AllocBuffer(ffi::Array shape, DataType dtype, ffi::Optional frame = builder->GetLastFrame()) { frame.value()->root_alloc_buffers.push_back(buffer); } else { - LOG(FATAL) << "ValueError: Block frame or PrimFunc frame not find. Please ensure " - "'T.alloc_buffer' is called under T.sblock() or T.prim_func()"; + TVM_FFI_THROW(ValueError) << "Block frame or PrimFunc frame not find. Please ensure " + "'T.alloc_buffer' is called under T.sblock() or T.prim_func()"; } return buffer; } @@ -288,14 +292,14 @@ IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) { frame->iter_vars.push_back(iter_var); frame->iter_values.push_back(binding); } else { - LOG(FATAL) << "TypeError: The last frame is not SBlockFrame"; + TVM_FFI_THROW(TypeError) << "The last frame is not SBlockFrame"; } return iter_var; } #define TVM_TIR_IR_BUILDER_AXIS(Method, Kind, Name) \ Var Method(Range dom, PrimExpr binding, DataType dtype) { \ - ICHECK(dom.defined()) << Name << " axis must have a domain"; \ + TVM_FFI_ICHECK(dom.defined()) << Name << " axis must have a domain"; \ int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); \ return PushBlockVar(IterVar(/*dom=*/dom, /*var=*/Var("", dtype.with_bits(bits)), \ /*iter_type=*/Kind, /*thread_tag=*/""), \ @@ -311,18 +315,18 @@ TVM_TIR_IR_BUILDER_AXIS(Opaque, tvm::tir::IterVarType::kOpaque, "Opaque"); ffi::Array Remap(ffi::String kinds, ffi::Array bindings, DataType dtype) { using namespace tvm::tir; ffi::Array results; - ICHECK_EQ(kinds.size(), bindings.size()); + TVM_FFI_ICHECK_EQ(kinds.size(), bindings.size()); int n = bindings.size(); results.reserve(n); for (int i = 0; i < n; ++i) { char c = kinds.c_str()[i]; PrimExpr e = bindings[i]; const VarNode* v = e.as(); - ICHECK(v) << "TypeError: Only Var is supported in T.axis.remap"; + TVM_FFI_CHECK(v, TypeError) << "Only Var is supported in T.axis.remap"; Range dom{nullptr}; for (const auto& frame : IRBuilder::Current()->frames) { if (const auto* for_frame = frame.as()) { - ICHECK_EQ(for_frame->doms.size(), for_frame->vars.size()); + TVM_FFI_ICHECK_EQ(for_frame->doms.size(), for_frame->vars.size()); int n = for_frame->doms.size(); for (int i = 0; i < n; ++i) { if (for_frame->vars[i].get() == v) { @@ -335,7 +339,8 @@ ffi::Array Remap(ffi::String kinds, ffi::Array bindings, DataType } } } - ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << ffi::GetRef(v); + TVM_FFI_CHECK(dom.defined(), TypeError) + << "Variable is not in the loop: " << ffi::GetRef(v); DataType dtype = v->dtype; if (c == 'S') { results.push_back(PushBlockVar(IterVar(/*dom=*/dom, @@ -352,7 +357,7 @@ ffi::Array Remap(ffi::String kinds, ffi::Array bindings, DataType e) ->var); } else { - LOG(FATAL) << "Unknown axis kind: " << c; + TVM_FFI_THROW(InternalError) << "Unknown axis kind: " << c; } } return results; @@ -374,9 +379,9 @@ ffi::Array Remap(ffi::String kinds, ffi::Array bindings, DataType n->f_make_for_loop = [annotations](ffi::Array vars, ffi::Array doms, \ ffi::Array> steps, \ tvm::tir::Stmt body) { \ - ICHECK_EQ(vars.size(), 1); \ - ICHECK_EQ(doms.size(), 1); \ - ICHECK_EQ(steps.size(), 1); \ + TVM_FFI_ICHECK_EQ(vars.size(), 1); \ + TVM_FFI_ICHECK_EQ(doms.size(), 1); \ + TVM_FFI_ICHECK_EQ(steps.size(), 1); \ return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, std::nullopt, \ annotations.value_or(ffi::Map()), steps[0]); \ }; \ @@ -404,9 +409,9 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread, n->f_make_for_loop = [annotations, thread, dtype](ffi::Array vars, ffi::Array doms, ffi::Array> steps, Stmt body) -> For { - ICHECK_EQ(vars.size(), 1); - ICHECK_EQ(doms.size(), 1); - ICHECK(steps.size() == 1 && (!steps[0].has_value() || is_one(*steps[0]))); + TVM_FFI_ICHECK_EQ(vars.size(), 1); + TVM_FFI_ICHECK_EQ(doms.size(), 1); + TVM_FFI_ICHECK(steps.size() == 1 && (!steps[0].has_value() || is_one(*steps[0]))); IterVar iter_var(Range(nullptr), Var("iter", dtype), IterVarType::kThreadIndex, thread); return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var, annotations.value_or(ffi::Map()), std::nullopt); @@ -427,8 +432,8 @@ ForFrame Grid(ffi::Array extents) { } n->f_make_for_loop = [](ffi::Array vars, ffi::Array doms, ffi::Array> steps, Stmt body) -> Stmt { - ICHECK_EQ(vars.size(), doms.size()); - ICHECK_EQ(vars.size(), steps.size()); + TVM_FFI_ICHECK_EQ(vars.size(), doms.size()); + TVM_FFI_ICHECK_EQ(vars.size(), steps.size()); int n = vars.size(); for (int i = n - 1; i >= 0; --i) { Range dom = doms[i]; @@ -475,19 +480,19 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { if (ffi::Optional opt_iter_var = opt_frame.value()->env_threads.Get(var)) { iter_var = opt_iter_var.value(); } else { - LOG(FATAL) << "ValueError: " << var->name_hint - << " is not an env_thread created using T.env_thread."; + TVM_FFI_THROW(ValueError) << var->name_hint + << " is not an env_thread created using T.env_thread."; } } else { - LOG(FATAL) << "LaunchThread can only be used inside a PrimFunc"; + TVM_FFI_THROW(InternalError) << "LaunchThread can only be used inside a PrimFunc"; } ObjectPtr n = ffi::make_object(); if (!iter_var->dom.defined()) { const_cast(iter_var.get())->dom = Range(tvm::tir::make_zero(extent.dtype()), extent); } else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) { - LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. " - << iter_var->dom->extent << " vs " << extent; + TVM_FFI_THROW(ValueError) << "Inconsistent extents of environment thread. " + << iter_var->dom->extent << " vs " << extent; } n->iter_var = iter_var; n->extent = extent; @@ -554,7 +559,7 @@ Var EnvThread(ffi::String thread_tag, DataType dtype) { if (ffi::Optional opt_frame = IRBuilder::Current()->FindFrame()) { opt_frame.value()->env_threads.Set(var, iter_var); } else { - LOG(FATAL) << "EnvThread can only be used inside a PrimFunc"; + TVM_FFI_THROW(InternalError) << "EnvThread can only be used inside a PrimFunc"; } return var; } @@ -565,7 +570,7 @@ void BufferStore(Buffer buffer, PrimExpr value, ffi::Array indices, bool is_index_scalable = indices.empty() ? false : indices.back().dtype().is_scalable_vector(); bool is_buffer_dtype_scalable = buffer_dtype.is_scalable_vector(); - ICHECK(!(is_index_scalable && is_buffer_dtype_scalable)) + TVM_FFI_ICHECK(!(is_index_scalable && is_buffer_dtype_scalable)) << "Index dtype and buffer dtype can't both be scalable."; int index_lanes; @@ -589,7 +594,7 @@ void BufferStore(Buffer buffer, PrimExpr value, ffi::Array indices, runtime::DataType rhs_dtype = value->dtype; if (lhs_dtype != rhs_dtype) { - ICHECK(lhs_dtype.is_scalable_vector() == rhs_dtype.is_scalable_vector()) + TVM_FFI_ICHECK(lhs_dtype.is_scalable_vector() == rhs_dtype.is_scalable_vector()) << "Can't mix scalable and fixed length vectors in a statement"; bool lanes_match = false; @@ -600,9 +605,9 @@ void BufferStore(Buffer buffer, PrimExpr value, ffi::Array indices, } if (!lanes_match) { - LOG(FATAL) << "TypeError: Incompatible types in BufferStore" - << ": LHS is `" << lhs_dtype << "`, RHS is `" << rhs_dtype - << "`, indexing lanes: " << index_lanes; + TVM_FFI_THROW(TypeError) << "Incompatible types in BufferStore" + << ": LHS is `" << lhs_dtype << "`, RHS is `" << rhs_dtype + << "`, indexing lanes: " << index_lanes; } if (lhs_dtype.code() != rhs_dtype.code()) { if ( @@ -697,7 +702,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (auto buffer = obj.as()) { return Arg(name, buffer.value()); } - LOG(FATAL) << "ValueError: Unexpected type for TIR Arg: " << obj->GetTypeKey(); + TVM_FFI_THROW(ValueError) << "Unexpected type for TIR Arg: " << obj->GetTypeKey(); throw; }) .def("script.ir_builder.tir.FuncName", FuncName) @@ -739,8 +744,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (auto str = thread_tag_or_var.as()) { return LaunchThread(str.value(), extent); } else { - LOG(FATAL) << "ValueError: Unexpected type for TIR LaunchThread: " - << thread_tag_or_var.GetTypeKey(); + TVM_FFI_THROW(ValueError) + << "Unexpected type for TIR LaunchThread: " << thread_tag_or_var.GetTypeKey(); throw; } }) diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index cabf418a10af..a53356b494ed 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -36,12 +36,12 @@ namespace tir { inline void AddToParent(tvm::tir::Stmt stmt) { IRBuilder builder = IRBuilder::Current(); if (builder->frames.empty()) { - ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + TVM_FFI_CHECK(!builder->result.defined(), ValueError) << "Builder.result has already been set"; builder->result = stmt; } else if (const auto* tir_frame = builder->frames.back().as()) { ffi::GetRef(tir_frame)->stmts.push_back(stmt); } else { - LOG(FATAL) << "TypeError: Unsupported frame type: " << builder->frames.back(); + TVM_FFI_THROW(TypeError) << "Unsupported frame type: " << builder->frames.back(); } } @@ -64,13 +64,14 @@ inline PrimFuncFrame FindPrimFuncFrame(const ffi::String& method) { return frame.value(); } else if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) { - LOG(FATAL) << "ValueError: " << method << " must be called at the top of a PrimFunc. " - << "While " << method << " did occur within the PrimFunc \"" << frame.value()->name - << "\", other frames (e.g. block/if/else/let) had been introduced since the " - << "PrimFunc's frame"; + TVM_FFI_THROW(ValueError) + << method << " must be called at the top of a PrimFunc. " + << "While " << method << " did occur within the PrimFunc \"" << frame.value()->name + << "\", other frames (e.g. block/if/else/let) had been introduced since the " + << "PrimFunc's frame"; } else { - LOG(FATAL) << "ValueError: " << method << " must be called at the top of a PrimFunc, " - << "but " << method << " occurred outside of any T.prim_func() frame"; + TVM_FFI_THROW(ValueError) << method << " must be called at the top of a PrimFunc, " + << "but " << method << " occurred outside of any T.prim_func() frame"; } throw; } @@ -84,13 +85,14 @@ inline SBlockFrame FindSBlockFrame(const ffi::String& method) { if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) { return frame.value(); } else if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) { - LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.sblock(). " - << "While " << method << " did occur within the block \"" << frame.value()->name - << "\", other frames (e.g. if/else/let) had been introduced since the T.sblock(\"" - << frame.value()->name << "\") frame"; + TVM_FFI_THROW(ValueError) + << method << " must be called at the top of a T.sblock(). " + << "While " << method << " did occur within the block \"" << frame.value()->name + << "\", other frames (e.g. if/else/let) had been introduced since the T.sblock(\"" + << frame.value()->name << "\") frame"; } else { - LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.sblock(), " - << "but " << method << " occurred outside of any T.sblock() frame"; + TVM_FFI_THROW(ValueError) << method << " must be called at the top of a T.sblock(), " + << "but " << method << " occurred outside of any T.sblock() frame"; } throw; } @@ -104,14 +106,15 @@ inline IfFrame FindIfFrame(const ffi::String& method) { if (ffi::Optional frame = IRBuilder::Current()->GetLastFrame()) { return frame.value(); } else if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) { - LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.if_(). " - << "While " << method << " did occur within the conditional based on (" - << frame.value()->condition - << "), other frames (e.g. if/else/let) had been introduced since the " - << "IfThenElse frame"; + TVM_FFI_THROW(ValueError) << method << " must be called at the top of a T.if_(). " + << "While " << method + << " did occur within the conditional based on (" + << frame.value()->condition + << "), other frames (e.g. if/else/let) had been introduced since the " + << "IfThenElse frame"; } else { - LOG(FATAL) << "ValueError: IfThenElse frame not find. Please ensure '" << method - << "' is called under T.if_()"; + TVM_FFI_THROW(ValueError) << "IfThenElse frame not find. Please ensure '" << method + << "' is called under T.if_()"; } throw; } diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index e5d72c002da0..b3e767d0cfc4 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -164,10 +164,10 @@ SliceDoc::SliceDoc(ffi::Optional start, ffi::Optional stop, } AssignDoc::AssignDoc(ExprDoc lhs, ffi::Optional rhs, ffi::Optional annotation) { - CHECK(rhs.defined() || annotation.defined()) - << "ValueError: At least one of rhs and annotation needs to be non-null for AssignDoc."; - CHECK(lhs->IsInstance() || annotation == nullptr) - << "ValueError: annotation can only be nonnull if lhs is an identifier."; + TVM_FFI_CHECK(rhs.defined() || annotation.defined(), ValueError) + << "At least one of rhs and annotation needs to be non-null for AssignDoc."; + TVM_FFI_CHECK(lhs->IsInstance() || annotation == nullptr, ValueError) + << "annotation can only be nonnull if lhs is an identifier."; ObjectPtr n = ffi::make_object(); n->lhs = lhs; @@ -177,8 +177,8 @@ AssignDoc::AssignDoc(ExprDoc lhs, ffi::Optional rhs, ffi::Optional then_branch, ffi::Array else_branch) { - CHECK(!then_branch.empty() || !else_branch.empty()) - << "ValueError: At least one of the then branch or else branch needs to be non-empty."; + TVM_FFI_CHECK(!then_branch.empty() || !else_branch.empty(), ValueError) + << "At least one of the then branch or else branch needs to be non-empty."; ObjectPtr n = ffi::make_object(); n->predicate = predicate; diff --git a/src/script/printer/doc_printer/base_doc_printer.cc b/src/script/printer/doc_printer/base_doc_printer.cc index 77990c8048c5..ad81297f97be 100644 --- a/src/script/printer/doc_printer/base_doc_printer.cc +++ b/src/script/printer/doc_printer/base_doc_printer.cc @@ -343,7 +343,7 @@ void DocPrinter::PrintDoc(const Doc& doc) { } else if (auto doc_node = doc.as()) { PrintTypedDoc(doc_node.value()); } else { - LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "Do not know how to print " << doc->GetTypeKey(); throw; } diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index 5627315c0387..db04e7427acd 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -128,9 +128,9 @@ ExprPrecedence GetExprPrecedence(const ExprDoc& doc) { if (const auto* op_doc = doc.as()) { size_t kind = static_cast(op_doc->kind); - ICHECK_LT(kind, op_kind_precedence.size()) << "ValueError: Invalid operation: " << kind; + TVM_FFI_CHECK_LT(kind, op_kind_precedence.size(), ValueError) << "Invalid operation: " << kind; ExprPrecedence precedence = op_kind_precedence[kind]; - ICHECK(precedence != ExprPrecedence::kUnkown) + TVM_FFI_ICHECK(precedence != ExprPrecedence::kUnkown) << "Precedence for operator " << static_cast(op_doc->kind) << " is unknown"; return precedence; } @@ -138,7 +138,7 @@ ExprPrecedence GetExprPrecedence(const ExprDoc& doc) { if (it != doc_type_precedence.end()) { return it->second; } - ICHECK(false) << "Precedence for doc type " << doc->GetTypeKey() << " is unknown"; + TVM_FFI_ICHECK(false) << "Precedence for doc type " << doc->GetTypeKey() << " is unknown"; throw; } @@ -256,8 +256,8 @@ class PythonDocPrinter : public DocPrinter { if (stmt->comment.has_value()) { const std::string& comment = stmt->comment.value(); bool has_newline = std::find(comment.begin(), comment.end(), '\n') != comment.end(); - CHECK(!has_newline) << "ValueError: the comment string of " << stmt->GetTypeKey() - << " cannot have newline."; + TVM_FFI_CHECK(!has_newline, ValueError) + << "the comment string of " << stmt->GetTypeKey() << " cannot have newline."; size_t start_pos = output_.tellp(); output_ << " # " << comment; size_t end_pos = output_.tellp(); @@ -355,7 +355,7 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { } else if (const auto opt_str = value.as()) { output_ << "\"" << support::StrEscape((*opt_str).data(), (*opt_str).size()) << "\""; } else { - LOG(FATAL) << "TypeError: Unsupported literal value type: " << value.GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Unsupported literal value type: " << value.GetTypeKey(); } } @@ -417,10 +417,10 @@ const std::string OperatorToString(OperationDocNode::Kind operation_kind) { }(); auto op_index = static_cast(operation_kind); - ICHECK_LT(op_index, op_kind2str.size()); + TVM_FFI_ICHECK_LT(op_index, op_kind2str.size()); const std::string str = op_kind2str[op_index]; - ICHECK(!str.empty()) << "OperationDocNode::Kind " << static_cast(operation_kind) - << " cannot be converted to operator token in Python directly."; + TVM_FFI_ICHECK(!str.empty()) << "OperationDocNode::Kind " << static_cast(operation_kind) + << " cannot be converted to operator token in Python directly."; return str; } @@ -428,7 +428,7 @@ void PythonDocPrinter::PrintTypedDoc(const OperationDoc& doc) { using OpKind = OperationDocNode::Kind; if (doc->kind < OpKind::kUnaryEnd) { // Unary Operators - ICHECK_EQ(doc->operands.size(), 1); + TVM_FFI_ICHECK_EQ(doc->operands.size(), 1); output_ << OperatorToString(doc->kind); PrintChildExpr(doc->operands[0], doc); } else if (doc->kind == OpKind::kPow) { @@ -436,26 +436,27 @@ void PythonDocPrinter::PrintTypedDoc(const OperationDoc& doc) { // It's right-associative and binds less tightly than unary operator on its right. // https://docs.python.org/3/reference/expressions.html#the-power-operator // https://docs.python.org/3/reference/expressions.html#operator-precedence - ICHECK_EQ(doc->operands.size(), 2); + TVM_FFI_ICHECK_EQ(doc->operands.size(), 2); PrintChildExprConservatively(doc->operands[0], doc); output_ << " ** "; PrintChildExpr(doc->operands[1], ExprPrecedence::kUnary); } else if (doc->kind < OpKind::kBinaryEnd) { // Binary Operator - ICHECK_EQ(doc->operands.size(), 2); + TVM_FFI_ICHECK_EQ(doc->operands.size(), 2); PrintChildExpr(doc->operands[0], doc); output_ << " " << OperatorToString(doc->kind) << " "; PrintChildExprConservatively(doc->operands[1], doc); } else if (doc->kind == OpKind::kIfThenElse) { - ICHECK_EQ(doc->operands.size(), 3) - << "ValueError: IfThenElse requires 3 operands, but got " << doc->operands.size(); + TVM_FFI_CHECK_EQ(doc->operands.size(), 3, ValueError) + << "IfThenElse requires 3 operands, but got " << doc->operands.size(); PrintChildExpr(doc->operands[1], doc); output_ << " if "; PrintChildExprConservatively(doc->operands[0], doc); output_ << " else "; PrintChildExprConservatively(doc->operands[2], doc); } else { - LOG(FATAL) << "Unknown OperationDocNode::Kind " << static_cast(doc->kind); + TVM_FFI_THROW(InternalError) << "Unknown OperationDocNode::Kind " + << static_cast(doc->kind); throw; } } @@ -477,7 +478,7 @@ void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) { } // Print keyword args - ICHECK_EQ(doc->kwargs_keys.size(), doc->kwargs_values.size()) + TVM_FFI_ICHECK_EQ(doc->kwargs_keys.size(), doc->kwargs_values.size()) << "CallDoc should have equal number of elements in kwargs_keys and kwargs_values."; for (size_t i = 0; i < doc->kwargs_keys.size(); i++) { if (is_first) { @@ -519,7 +520,7 @@ void PythonDocPrinter::PrintTypedDoc(const TupleDoc& doc) { } void PythonDocPrinter::PrintTypedDoc(const DictDoc& doc) { - ICHECK_EQ(doc->keys.size(), doc->values.size()) + TVM_FFI_ICHECK_EQ(doc->keys.size(), doc->values.size()) << "DictDoc should have equal number of elements in keys and values."; output_ << "{"; size_t idx = 0; @@ -664,7 +665,8 @@ void PythonDocPrinter::PrintTypedDoc(const ReturnDoc& doc) { void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) { for (const AssignDoc& arg_doc : doc->args) { - ICHECK(!arg_doc->comment.has_value()) << "Function arg cannot have comment attached to them."; + TVM_FFI_ICHECK(!arg_doc->comment.has_value()) + << "Function arg cannot have comment attached to them."; } PrintDecorators(doc->decorators); diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index dff435833778..ee1dcde1035f 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -42,8 +42,8 @@ struct SortableFunction { } else if (obj.second->GetTypeKey() == "relax.expr.Function") { priority = 3; } else { - LOG(FATAL) << "TypeError: TVMScript cannot print functions of type: " - << obj.second->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "TVMScript cannot print functions of type: " + << obj.second->GetTypeKey(); } } @@ -103,10 +103,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) AssignDoc assignment(lhs, expr.value(), std::nullopt); (*f)->stmts.push_back(assignment); } else { - LOG(FATAL) << "TypeError: " - << "Expected IRModule to only contain functions, " - << " but mod[" << gv->name_hint << "] with type " << base_func->GetTypeKey() - << " produced Doc type of " << doc->GetTypeKey(); + TVM_FFI_THROW(TypeError) + << "Expected IRModule to only contain functions, " + << " but mod[" << gv->name_hint << "] with type " << base_func->GetTypeKey() + << " produced Doc type of " << doc->GetTypeKey(); } } return HeaderWrapper(d, ClassDoc(module_doc, {IR(d, "ir_module")}, (*f)->stmts)); diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 8ebbedfef78d..ee204ccee469 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -68,7 +68,7 @@ IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, } void IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, DocCreator doc_factory) { - ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj; + TVM_FFI_ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj; obj2info.insert({obj, VariableInfo{std::move(doc_factory), std::nullopt}}); frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); }); } @@ -82,7 +82,7 @@ ffi::Optional IRDocsifierNode::GetVarDoc(const ObjectRef& obj) const { } ExprDoc IRDocsifierNode::AddMetadata(const ffi::Any& obj) { - ICHECK(obj != nullptr) << "TypeError: Cannot add nullptr to metadata"; + TVM_FFI_CHECK(obj != nullptr, TypeError) << "Cannot add nullptr to metadata"; ffi::String key = obj.GetTypeKey(); ffi::Array& array = metadata[key]; int index = std::find_if(array.begin(), array.end(), @@ -96,7 +96,7 @@ ExprDoc IRDocsifierNode::AddMetadata(const ffi::Any& obj) { } void IRDocsifierNode::AddGlobalInfo(const ffi::String& name, const GlobalInfo& ginfo) { - ICHECK(ginfo.defined()) << "TypeError: Cannot add nullptr to global_infos"; + TVM_FFI_CHECK(ginfo.defined(), TypeError) << "Cannot add nullptr to global_infos"; ffi::Array& array = global_infos[name]; array.push_back(ginfo); } @@ -105,7 +105,7 @@ bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const { return obj2info void IRDocsifierNode::RemoveVar(const ObjectRef& obj) { auto it = obj2info.find(obj); - ICHECK(it != obj2info.end()) << "No such object: " << obj; + TVM_FFI_ICHECK(it != obj2info.end()) << "No such object: " << obj; if (it->second.name.has_value()) { defined_names.erase(it->second.name.value()); } diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index 6d96327e2db4..f2baac0f5375 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -41,7 +41,7 @@ class AttrPrinter { } } else { const TVMFFITypeInfo* attrs_tinfo = TVMFFIGetTypeInfo(attrs->type_index()); - ICHECK(attrs_tinfo->metadata != nullptr) + TVM_FFI_ICHECK(attrs_tinfo->metadata != nullptr) << "Object `" << attrs->GetTypeKey() << "` misses reflection registration and do not support serialization"; // new printing mechanism using the new reflection @@ -81,8 +81,8 @@ ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessP !n->op.same_as(call_tir_inplace_op)) { return std::nullopt; } - ICHECK(n->args.size() == 2 || n->args.size() == 3); - ICHECK(n->sinfo_args.size() == 1); + TVM_FFI_ICHECK(n->args.size() == 2 || n->args.size() == 3); + TVM_FFI_ICHECK(n->sinfo_args.size() == 1); ffi::Array args; ffi::Array kwargs_keys; ffi::Array kwargs_values; @@ -166,7 +166,7 @@ ffi::Optional PrintAssertOp(const relax::Call& n, const AccessPath& n_p if (!n->op.same_as(assert_op)) { return std::nullopt; } - ICHECK(n->args.size() >= 2); + TVM_FFI_ICHECK(n->args.size() >= 2); // special handling: it is important to indicate that the format string (second argument) // is the _format_ string, or else roundtripping will fail // (the format string will be interpreted as an argument and there will be a new default format @@ -191,7 +191,7 @@ ffi::Optional PrintHintOnDevice(const relax::Call& n, const AccessPath& args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayItem(0), d)); ffi::Array kwargs_keys; ffi::Array kwargs_values; - ICHECK(n->attrs.defined()); + TVM_FFI_ICHECK(n->attrs.defined()); if (n->attrs.as()) { AttrPrinter(n_p->Attr("attrs"), d, &kwargs_keys, &kwargs_values)(n->attrs); ExprDoc scope_val = kwargs_values.back(); @@ -214,7 +214,7 @@ ffi::Optional PrintToVDevice(const relax::Call& n, const AccessPath& n_ args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayItem(0), d)); ffi::Array kwargs_keys; ffi::Array kwargs_values; - ICHECK(n->attrs.defined()); + TVM_FFI_ICHECK(n->attrs.defined()); if (const auto* attrs = n->attrs.as()) { VDevice vdev = attrs->dst_vdevice; std::string dev_kind = vdev->target->kind->name; @@ -233,7 +233,7 @@ ffi::Optional PrintRelaxPrint(const relax::Call& n, const AccessPath& n if (!n->op.same_as(print_op)) { return std::nullopt; } - ICHECK(n->args.size() >= 1); + TVM_FFI_ICHECK(n->args.size() >= 1); // special handling: it is important to indicate that the format string (first argument) // is the _format_ string, or else roundtripping will fail // (the format string will be interpreted as an argument and there will be a new default format @@ -289,7 +289,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) n->op->IsInstance()) { prefix = d->AsDoc(n->op, n_p->Attr("op")); } else { - LOG(FATAL) << "TypeError: Unsupported op: " << n->op->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Unsupported op: " << n->op->GetTypeKey(); } // Step 2. Print args if (!n->args.empty()) { diff --git a/src/script/printer/relax/distributed.cc b/src/script/printer/relax/distributed.cc index d1a29be24f5e..51fae05bf626 100644 --- a/src/script/printer/relax/distributed.cc +++ b/src/script/printer/relax/distributed.cc @@ -120,7 +120,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } } - LOG(FATAL) << "Cannot find device mesh in global infos"; + TVM_FFI_THROW(InternalError) << "Cannot find device mesh in global infos"; } }); diff --git a/src/script/printer/relax/region.cc b/src/script/printer/relax/region.cc index a28967cb4194..c42d2d5567c5 100644 --- a/src/script/printer/relax/region.cc +++ b/src/script/printer/relax/region.cc @@ -35,7 +35,7 @@ ffi::Array PrintSeqExpr(const relax::SeqExpr& n, const AccessPath& n_p, } else if (const auto* stmt = block.as()) { stmts->push_back(ffi::GetRef(stmt)); } else { - LOG(FATAL) << "TypeError: Unknown type: " << block->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Unknown type: " << block->GetTypeKey(); } } ExprDoc ret = d->AsDoc(n->body, n_p->Attr("body")); @@ -61,14 +61,14 @@ ffi::Array PrintBindingBlock(const relax::BindingBlock& n, const Access for (int i = 0, l = bindings.size(); i < l; ++i) { const relax::Binding& binding = bindings[i]; AccessPath binding_p = bindings_p->ArrayItem(i); - ICHECK(binding->var.defined()); + TVM_FFI_ICHECK(binding->var.defined()); Doc binding_doc = d->AsDoc(binding, binding_p); if (const auto* stmt = binding_doc.as()) { stmts.push_back(ffi::GetRef(stmt)); } else if (const auto* stmt_block = binding_doc.as()) { stmts.insert(stmts.end(), stmt_block->stmts.begin(), stmt_block->stmts.end()); } else { - LOG(FATAL) << "TypeError: Unknown type: " << binding_doc->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Unknown type: " << binding_doc->GetTypeKey(); } if (non_dataflow_vars != nullptr && !binding->var->IsInstance()) { non_dataflow_vars->push_back(d->AsDoc(binding->var, binding_p->Attr("var"))); diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index 0c1a2cd26035..ae35f018cbfe 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -42,20 +42,20 @@ RelaxFrameNode* GetRelaxFrame(IRDocsifier d) { } Doc PrintTIRVar(tir::Var n, AccessPath n_p, IRDocsifier d) { - ICHECK(n->dtype.is_scalar()) << "TypeError: " - << "Relax only uses scalar TIR variables," - << "but received TIR variable " << n << " with dtype " << n->dtype; + TVM_FFI_CHECK(n->dtype.is_scalar(), TypeError) + << "Relax only uses scalar TIR variables," + << "but received TIR variable " << n << " with dtype " << n->dtype; if (!d->IsVarDefined(n)) { RelaxFrameNode* f = GetRelaxFrame(d); // There should be at least one Relax frame if (f == nullptr) { - LOG(FATAL) << "IndexError: No relax environment is found when printing a TIR var under " - "relax's dispatch token"; + TVM_FFI_THROW(IndexError) << "No relax environment is found when printing a TIR var under " + "relax's dispatch token"; } // If the Relax function frame is collecting func vars if (f->func_vars) { - ICHECK(f->is_func); + TVM_FFI_ICHECK(f->is_func); f->func_vars->insert(n.get()); } IdDoc var = d->Define(n, ffi::GetRef(f), n->name_hint.empty() ? "v" : n->name_hint); @@ -65,7 +65,7 @@ Doc PrintTIRVar(tir::Var n, AccessPath n_p, IRDocsifier d) { if (ffi::Optional doc = d->GetVarDoc(n)) { return doc.value(); } - LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << n; + TVM_FFI_THROW(IndexError) << "Variable is not defined in the environment: " << n; TVM_FFI_UNREACHABLE(); } @@ -99,14 +99,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "relax", [](tvm::IRModule mod, AccessPath n_p, IRDocsifier d) -> Doc { // ffi::Optional doc = d->GetVarDoc(mod); - ICHECK(doc) << "Unable to print IRModule before definition in Relax."; + TVM_FFI_ICHECK(doc) << "Unable to print IRModule before definition in Relax."; if (d->cfg->module_alias.empty()) { // Use Module Name directly return doc.value(); } RelaxFrameNode* f = GetRelaxFrame(d); - ICHECK(f != nullptr && f->is_func) - << "IndexError: No relax environment is found when printing a module alias var " + TVM_FFI_CHECK(f != nullptr && f->is_func, IndexError) + << "No relax environment is found when printing a module alias var " "under relax's dispatch token"; if (!f->module_alias_printed) { // If the module_alias is not defined before, define it. diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index a5b6141dd040..d022c8b4bba4 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -26,7 +26,7 @@ Doc PrintBlock(IRDocsifier d, tir::SBlock block, AccessPath block_p, // ffi::Optional opt_realize, ffi::Optional opt_realize_p) { With frame(d, block); - ICHECK_EQ(opt_realize.defined(), opt_realize_p.defined()); + TVM_FFI_ICHECK_EQ(opt_realize.defined(), opt_realize_p.defined()); const tir::SBlockRealizeNode* realize = opt_realize.defined() ? opt_realize.value().get() : nullptr; AccessPath realize_p = *opt_realize_p; @@ -80,8 +80,8 @@ Doc PrintBlock(IRDocsifier d, tir::SBlock block, AccessPath block_p, // } else if (iter_var->iter_type == tir::IterVarType::kOpaque) { rhs = rhs->Attr("opaque"); } else { - LOG(FATAL) << "ValueError: Unknown IterVarType in block signature: " - << tir::IterVarType2String(iter_var->iter_type); + TVM_FFI_THROW(ValueError) << "Unknown IterVarType in block signature: " + << tir::IterVarType2String(iter_var->iter_type); } ExprDoc dom{ffi::UnsafeInit()}; if (tir::is_zero(iter_var->dom->min)) { @@ -151,7 +151,7 @@ Doc PrintBlock(IRDocsifier d, tir::SBlock block, AccessPath block_p, // // Step 2. Handle block predicate if (realize) { - ICHECK(realize->predicate.defined() && realize->predicate->dtype.is_bool()); + TVM_FFI_ICHECK(realize->predicate.defined() && realize->predicate->dtype.is_bool()); if (!tir::is_one(realize->predicate)) { (*frame)->stmts.push_back(ExprStmtDoc( TIR(d, "where") diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 4057b1d09bfc..e9616d1dc6bf 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -54,7 +54,7 @@ ffi::Map BufferAttrs(tir::Buffer buffer, const AccessPath& return e->IsInstance() && !d->IsVarDefined(e); }; auto add_out_of_line_var_def = [&](const Var& var, const AccessPath& var_p) { - ICHECK(!d->IsVarDefined(var)); + TVM_FFI_ICHECK(!d->IsVarDefined(var)); ExprDoc lhs = DefineVar(var, frame, d); lhs->source_paths.push_back(var_p); var_def_lhs.push_back(lhs); @@ -62,7 +62,7 @@ ffi::Map BufferAttrs(tir::Buffer buffer, const AccessPath& }; auto try_inline_def = [&](const PrimExpr& e, const AccessPath& e_p, std::function inline_f) { - ICHECK(is_new_var(e)); + TVM_FFI_ICHECK(is_new_var(e)); Var var = Downcast(e); if (use_count[var.get()] == 1) { d->Define(e, frame, inline_f); @@ -320,7 +320,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // if (ffi::Optional doc = d->GetVarDoc(buffer)) { return doc.value(); } - LOG(FATAL) << "IndexError: Buffer is not defined in the environment: " << buffer; + TVM_FFI_THROW(IndexError) << "Buffer is not defined in the environment: " << buffer; TVM_FFI_UNREACHABLE(); }); diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index da525aa35fc2..69b047b0027e 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -77,7 +77,7 @@ Doc PrintVar(const tir::Var& var, const AccessPath& var_p, const IRDocsifier& d) if (ffi::Optional doc = d->GetVarDoc(var)) { return doc.value(); } - LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << var->name_hint; + TVM_FFI_THROW(IndexError) << "Variable is not defined in the environment: " << var->name_hint; TVM_FFI_UNREACHABLE(); } @@ -168,7 +168,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::CommReducer r, AccessPath p, IRDocsifier d) -> Doc { - ICHECK_EQ(r->lhs.size(), r->rhs.size()); + TVM_FFI_ICHECK_EQ(r->lhs.size(), r->rhs.size()); ffi::Optional lambda; { With f(d, r); @@ -284,7 +284,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } else if (call->op.as()) { prefix = d->AsDoc(call->op, call_p->Attr("op")); } else { - LOG(FATAL) << "call: " << call; + TVM_FFI_THROW(InternalError) << "call: " << call; } ffi::Array args; int n_args = call->args.size(); @@ -313,7 +313,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return TIR(d, "reduce") ->Call({combiner}, {"source", "init", "axis", "condition", "value_index"}, {source, init, axis, condition, value_index}); - LOG(FATAL) << "ValueError: Reduce should never exist in TIR: " << r; + TVM_FFI_THROW(ValueError) << "Reduce should never exist in TIR: " << r; }); #define TVM_SCRIPT_PRINTER_DEF_BINARY(NodeType, OpString) \ diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc index b2e091f38019..ec0a31e44ff2 100644 --- a/src/script/printer/tir/for_loop.cc +++ b/src/script/printer/tir/for_loop.cc @@ -34,8 +34,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }; if (d->cfg->syntax_sugar) { for (const tir::ForNode* l = loop.get(); l != nullptr; l = l->body.as()) { - ICHECK(l->loop_var->dtype == l->min->dtype); - ICHECK(l->loop_var->dtype == l->extent->dtype); + TVM_FFI_ICHECK(l->loop_var->dtype == l->min->dtype); + TVM_FFI_ICHECK(l->loop_var->dtype == l->extent->dtype); if (l->kind != tir::ForKind::kSerial || // !tir::is_zero(l->min) || // !l->annotations.empty() || // @@ -98,7 +98,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) thread = LiteralDoc::Str(loop->thread_binding.value()->thread_tag, loop_p->Attr("thread_binding")); } else { - LOG(FATAL) << "ValueError: Unknown ForKind: " << tir::ForKind2String(loop->kind); + TVM_FFI_THROW(ValueError) << "Unknown ForKind: " << tir::ForKind2String(loop->kind); } ffi::Array args; ffi::Array kwargs_keys; diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index bfa999a5c68d..89e372aba43d 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -221,7 +221,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "tir", [](tvm::IRModule mod, AccessPath n_p, IRDocsifier d) -> Doc { // ffi::Optional doc = d->GetVarDoc(mod); - ICHECK(doc) << "Unable to print IRModule before definition in TIR."; + TVM_FFI_ICHECK(doc) << "Unable to print IRModule before definition in TIR."; return doc.value(); }); diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index bf9d2253ce71..eff58cf41169 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -44,11 +44,11 @@ bool AllowConciseScoping(const IRDocsifier& d, const ObjectRef& obj) { return false; } } - ICHECK(!d->frames.empty()); + TVM_FFI_ICHECK(!d->frames.empty()); if (const auto* f = d->frames.back().as()) { return f->allow_concise_scoping; } - LOG(FATAL) << "NotImplementedError: fragment printing"; + TVM_FFI_THROW(NotImplementedError) << "fragment printing"; } bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const IRDocsifier& d) { diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 8e9b9cdf1049..2ea588c5eeae 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -65,7 +65,7 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra } move_source_paths = true; } else { - LOG(FATAL) << "TypeError: Unexpected doc type: " << doc->GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Unexpected doc type: " << doc->GetTypeKey(); } std::ostringstream os; if (!d->metadata.empty()) { diff --git a/src/support/base64.h b/src/support/base64.h index afdb6509ac3a..af011e317538 100644 --- a/src/support/base64.h +++ b/src/support/base64.h @@ -157,7 +157,7 @@ class Base64InStream : public tvm::support::Stream { { // second byte temp_ch_ = reader_.GetChar(); - ICHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; + TVM_FFI_ICHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; nvalue |= DecodeTable[temp_ch_] << 12; *cptr++ = (nvalue >> 16) & 0xFF; --tlen; @@ -165,13 +165,13 @@ class Base64InStream : public tvm::support::Stream { { // third byte temp_ch_ = reader_.GetChar(); - ICHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; + TVM_FFI_ICHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; // handle termination if (temp_ch_ == '=') { temp_ch_ = reader_.GetChar(); - ICHECK(temp_ch_ == '=') << "invalid base64 format"; + TVM_FFI_ICHECK(temp_ch_ == '=') << "invalid base64 format"; temp_ch_ = reader_.GetChar(); - ICHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format"; + TVM_FFI_ICHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format"; break; } nvalue |= DecodeTable[temp_ch_] << 6; @@ -185,10 +185,10 @@ class Base64InStream : public tvm::support::Stream { { // fourth byte temp_ch_ = reader_.GetChar(); - ICHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; + TVM_FFI_ICHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; if (temp_ch_ == '=') { temp_ch_ = reader_.GetChar(); - ICHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format"; + TVM_FFI_ICHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format"; break; } nvalue |= DecodeTable[temp_ch_]; @@ -203,12 +203,12 @@ class Base64InStream : public tvm::support::Stream { temp_ch_ = reader_.GetChar(); } if (kStrictCheck) { - ICHECK_EQ(tlen, 0) << "Base64InStream: read incomplete"; + TVM_FFI_ICHECK_EQ(tlen, 0) << "Base64InStream: read incomplete"; } return size - tlen; } size_t Write(const void* ptr, size_t size) final { - LOG(FATAL) << "Base64InStream do not support write"; + TVM_FFI_THROW(InternalError) << "Base64InStream do not support write"; return 0; } @@ -252,7 +252,7 @@ class Base64OutStream : public tvm::support::Stream { return size; } virtual size_t Read(void* ptr, size_t size) { - LOG(FATAL) << "Base64OutStream do not support read"; + TVM_FFI_THROW(InternalError) << "Base64OutStream do not support read"; return 0; } /*! diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 8875046874e4..0f8806a117c6 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -63,7 +63,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](ffi::Shape shape) { return static_cast(shape.size()); }) .def("testing.GetShapeElem", [](ffi::Shape shape, int idx) { - ICHECK_LT(idx, shape.size()); + TVM_FFI_ICHECK_LT(idx, shape.size()); return shape[idx]; }) .def_packed("testing.test_wrap_callback", @@ -86,21 +86,21 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](ffi::PackedArgs args, ffi::Any* ret) { auto msg = args[0].cast(); *ret = ffi::TypedFunction( - [msg](int x, int y) { CHECK_EQ(x, y) << msg; }); + [msg](int x, int y) { TVM_FFI_ICHECK_EQ(x, y) << msg; }); }) .def_packed("testing.device_test", [](ffi::PackedArgs args, ffi::Any* ret) { auto dev = args[0].cast(); int dtype = args[1].cast(); int did = args[2].cast(); - CHECK_EQ(static_cast(dev.device_type), dtype); - CHECK_EQ(static_cast(dev.device_id), did); + TVM_FFI_ICHECK_EQ(static_cast(dev.device_type), dtype); + TVM_FFI_ICHECK_EQ(static_cast(dev.device_id), did); *ret = dev; }) .def_packed("testing.identity_cpp", [](ffi::PackedArgs args, ffi::Any* ret) { const auto identity_func = tvm::ffi::Function::GetGlobal("testing.identity_py"); - ICHECK(identity_func.has_value()) - << "AttributeError: \"testing.identity_py\" is not registered. Please check " + TVM_FFI_CHECK(identity_func.has_value(), AttributeError) + << "\"testing.identity_py\" is not registered. Please check " "if the python module is properly loaded"; *ret = (*identity_func)(args[0]); }); @@ -109,10 +109,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { // in src/api_test.cc void ErrorTest(int x, int y) { // raise ValueError - CHECK_EQ(x, y) << "ValueError: expect x and y to be equal."; + TVM_FFI_CHECK_EQ(x, y, ValueError) << "expect x and y to be equal."; if (x == 1) { // raise InternalError. - LOG(FATAL) << "InternalError: cannot reach here"; + TVM_FFI_THROW(InternalError) << "cannot reach here"; } } @@ -140,7 +140,7 @@ ffi::Optional FrontendTestModuleNode::GetFunction(const ffi::Stri if (name == kAddFunctionName) { return ffi::Function::FromTyped( [this, self_strong_ref](std::string func_name, ffi::Function pf) { - CHECK_NE(func_name, kAddFunctionName) + TVM_FFI_ICHECK_NE(func_name, kAddFunctionName) << "func_name: cannot be special function " << kAddFunctionName; functions_[func_name] = pf; }); @@ -195,15 +195,16 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("testing.AcceptsArrayOfPrimExpr", [](ffi::Array arr) -> ObjectRef { for (ObjectRef item : arr) { - CHECK(item->IsInstance()) << "Array contained " << item->GetTypeKey() - << " when it should contain PrimExpr"; + TVM_FFI_ICHECK(item->IsInstance()) + << "Array contained " << item->GetTypeKey() + << " when it should contain PrimExpr"; } return arr; }) .def("testing.AcceptsArrayOfVariant", [](ffi::Array> arr) -> ObjectRef { for (auto item : arr) { - CHECK(item.as() || item.as()) + TVM_FFI_ICHECK(item.as() || item.as()) << "Array should contain either PrimExpr or ffi::Function"; } return arr; @@ -211,7 +212,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("testing.AcceptsMapOfPrimExpr", [](ffi::Map map) -> ObjectRef { for (const auto& kv : map) { ObjectRef value = kv.second; - CHECK(value->IsInstance()) + TVM_FFI_ICHECK(value->IsInstance()) << "Map contained " << value->GetTypeKey() << " when it should contain PrimExpr"; } return map; diff --git a/src/support/nd_int_set.h b/src/support/nd_int_set.h index f63aaf92faca..01a9ee8ae398 100644 --- a/src/support/nd_int_set.h +++ b/src/support/nd_int_set.h @@ -81,7 +81,7 @@ inline NDIntSet NDIntSetFromPoint(const ffi::Array& indices) { * \param rhs The second N-dimensional integer set */ inline void NDIntSetUnionWith(NDIntSet* lhs, const NDIntSet& rhs) { - ICHECK_EQ(lhs->size(), rhs.size()); + TVM_FFI_ICHECK_EQ(lhs->size(), rhs.size()); int ndim = rhs.size(); for (int i = 0; i < ndim; ++i) { arith::IntSet& int_set = lhs->at(i); @@ -95,14 +95,14 @@ inline void NDIntSetUnionWith(NDIntSet* lhs, const NDIntSet& rhs) { * \return The result of the union */ inline NDIntSet NDIntSetUnion(const std::vector& nd_int_sets) { - ICHECK(!nd_int_sets.empty()); + TVM_FFI_ICHECK(!nd_int_sets.empty()); int n = nd_int_sets.size(); if (n == 1) { return nd_int_sets[0]; } int ndim = nd_int_sets[0].size(); for (int i = 1; i < n; ++i) { - ICHECK_EQ(nd_int_sets[i].size(), ndim); + TVM_FFI_ICHECK_EQ(nd_int_sets[i].size(), ndim); } NDIntSet result; result.reserve(ndim); diff --git a/src/support/parallel_for.cc b/src/support/parallel_for.cc index e90967562d16..46aa46b40960 100644 --- a/src/support/parallel_for.cc +++ b/src/support/parallel_for.cc @@ -34,8 +34,8 @@ namespace support { std::vector> rr_partitioner(int begin, int end, int step, int num_threads) { int total_task_count = (end - begin) / step; - ICHECK_GE(total_task_count, 0) << "Infinite loop condition with begin: " << begin - << " end: " << end << " step: " << step; + TVM_FFI_ICHECK_GE(total_task_count, 0) + << "Infinite loop condition with begin: " << begin << " end: " << end << " step: " << step; std::vector> ret; ret.reserve(num_threads); for (size_t thread = 0; begin < end; begin += step, thread = (thread + 1) % num_threads) { @@ -53,8 +53,9 @@ void parallel_for(int begin, int end, const std::function& f, int ste static std::mutex M_GLOBAL_PARALLEL_FOR_FLAG; { std::unique_lock l(M_GLOBAL_PARALLEL_FOR_FLAG); - ICHECK(!GLOBAL_PARALLEL_FOR_FLAG) << "There's another parallel_for running. Maybe you're " - << "currently inside another parallel_for loop."; + TVM_FFI_ICHECK(!GLOBAL_PARALLEL_FOR_FLAG) + << "There's another parallel_for running. Maybe you're " + << "currently inside another parallel_for loop."; GLOBAL_PARALLEL_FOR_FLAG = true; } @@ -81,7 +82,7 @@ void parallel_for(int begin, int end, const std::function& f, int ste } { std::unique_lock l(M_GLOBAL_PARALLEL_FOR_FLAG); - ICHECK(GLOBAL_PARALLEL_FOR_FLAG); + TVM_FFI_ICHECK(GLOBAL_PARALLEL_FOR_FLAG); GLOBAL_PARALLEL_FOR_FLAG = false; } try { @@ -89,7 +90,7 @@ void parallel_for(int begin, int end, const std::function& f, int ste i.get(); } } catch (const std::exception& e) { - LOG(FATAL) << "Parallel_for error with " << e.what(); + TVM_FFI_THROW(InternalError) << "Parallel_for error with " << e.what(); } } @@ -99,8 +100,8 @@ void parallel_for_dynamic(int begin, int end, int num_threads, if (begin == end) { return; } - CHECK_LE(begin, end) << "ValueError: The interval [begin, end) requires `begin <= end`"; - CHECK_GT(num_threads, 0) << "ValueError: `num_threads` should be positive"; + TVM_FFI_CHECK_LE(begin, end, ValueError) << "The interval [begin, end) requires `begin <= end`"; + TVM_FFI_CHECK_GT(num_threads, 0, ValueError) << "`num_threads` should be positive"; // Step 2. Launch threads // Step 2.1. Launch worker 1 to worker `num_threads - 1` std::atomic counter{begin}; @@ -125,7 +126,7 @@ void parallel_for_dynamic(int begin, int end, int num_threads, for (auto&& thread : threads) { thread.join(); } - LOG(FATAL) << "RuntimeError: parallel_for_dynamic error with " << e.what(); + TVM_FFI_THROW(RuntimeError) << "parallel_for_dynamic error with " << e.what(); } // Step 3. Join threads and check exceptions for (auto&& thread : threads) { @@ -136,7 +137,7 @@ void parallel_for_dynamic(int begin, int end, int num_threads, future.get(); } } catch (const std::exception& e) { - LOG(FATAL) << "RuntimeError: parallel_for_dynamic error with " << e.what(); + TVM_FFI_THROW(RuntimeError) << "parallel_for_dynamic error with " << e.what(); } } diff --git a/src/support/pipe.h b/src/support/pipe.h index 3ce60f2b7a9a..ec7a8ea14d9e 100644 --- a/src/support/pipe.h +++ b/src/support/pipe.h @@ -84,21 +84,21 @@ class Pipe : public tvm::support::Stream { return static_cast(nread); }; DWORD nread = static_cast(RetryCallOnEINTR(fread, GetLastErrorCode)); - ICHECK_EQ(static_cast(nread), size) << "Read Error: " << GetLastError(); + TVM_FFI_ICHECK_EQ(static_cast(nread), size) << "Read Error: " << GetLastError(); #else size_t nread = 0; while (size) { ssize_t nread_chunk = RetryCallOnEINTR([&]() { return read(handle_, ptr, size); }, GetLastErrorCode); - ICHECK_NE(nread_chunk, -1) << "Write Error: " << strerror(errno); + TVM_FFI_ICHECK_NE(nread_chunk, -1) << "Write Error: " << strerror(errno); if (nread_chunk == 0) { break; } - ICHECK_GE(nread_chunk, 0); - ICHECK_LE(nread_chunk, size) << "Read " << nread_chunk << " bytes, " - << "but only expected to read " << size << " bytes"; + TVM_FFI_ICHECK_GE(nread_chunk, 0); + TVM_FFI_ICHECK_LE(nread_chunk, size) << "Read " << nread_chunk << " bytes, " + << "but only expected to read " << size << " bytes"; size -= nread_chunk; ptr = static_cast(ptr) + nread_chunk; nread += nread_chunk; @@ -122,14 +122,14 @@ class Pipe : public tvm::support::Stream { return static_cast(nwrite); }; DWORD nwrite = static_cast(RetryCallOnEINTR(fwrite, GetLastErrorCode)); - ICHECK_EQ(static_cast(nwrite), size) << "Write Error: " << GetLastError(); + TVM_FFI_ICHECK_EQ(static_cast(nwrite), size) << "Write Error: " << GetLastError(); #else ssize_t nwrite = RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode); - ICHECK_NE(nwrite, -1) << "Write Error: " << strerror(errno); + TVM_FFI_ICHECK_NE(nwrite, -1) << "Write Error: " << strerror(errno); - ICHECK_LE(nwrite, size) << "Wrote " << nwrite << " bytes, " - << "but only expected to write " << size << " bytes"; + TVM_FFI_ICHECK_LE(nwrite, size) << "Wrote " << nwrite << " bytes, " + << "but only expected to write " << size << " bytes"; #endif diff --git a/src/support/ring_buffer.h b/src/support/ring_buffer.h index 866c9c4424e0..40d741a762e4 100644 --- a/src/support/ring_buffer.h +++ b/src/support/ring_buffer.h @@ -98,7 +98,7 @@ class RingBuffer { * \param size The number of bytes to read. */ void Read(void* data, size_t size) { - ICHECK_GE(bytes_available_, size); + TVM_FFI_ICHECK_GE(bytes_available_, size); size_t ncopy = std::min(size, ring_.size() - head_ptr_); memcpy(data, &ring_[0] + head_ptr_, ncopy); if (ncopy < size) { @@ -120,7 +120,7 @@ class RingBuffer { template size_t ReadWithCallback(FSend fsend, size_t max_nbytes) { size_t size = std::min(max_nbytes, bytes_available_); - ICHECK_NE(size, 0U); + TVM_FFI_ICHECK_NE(size, 0U); size_t ncopy = std::min(size, ring_.size() - head_ptr_); size_t nsend = fsend(&ring_[0] + head_ptr_, ncopy); if (ncopy == nsend && ncopy < size) { diff --git a/src/support/scalars.cc b/src/support/scalars.cc index 692746852694..a3836a849db2 100644 --- a/src/support/scalars.cc +++ b/src/support/scalars.cc @@ -51,7 +51,8 @@ runtime::Tensor IntImmToTensor(const IntImm& int_imm) { auto* array = reinterpret_cast(data->data); array[0] = int_imm->value; } else { - LOG(FATAL) << "Unrecognized numeric literal dtype: " << DLDataTypeToString(int_imm.dtype()); + TVM_FFI_THROW(InternalError) << "Unrecognized numeric literal dtype: " + << DLDataTypeToString(int_imm.dtype()); } return data; } @@ -69,7 +70,8 @@ runtime::Tensor FloatImmToTensor(const FloatImm& float_imm) { auto* array = reinterpret_cast(data->data); array[0] = float_imm->value; } else { - LOG(FATAL) << "Unrecognized numeric literal dtype: " << DLDataTypeToString(float_imm.dtype()); + TVM_FFI_THROW(InternalError) << "Unrecognized numeric literal dtype: " + << DLDataTypeToString(float_imm.dtype()); } return data; } @@ -85,7 +87,8 @@ runtime::Tensor BoolToTensor(bool value) { std::string TensorScalarToString(const runtime::Tensor& data) { std::ostringstream os; DataType dtype(data->dtype); - ICHECK_EQ(data->device.device_type, kDLCPU) << "Scalars must reside on the CPU to be printed"; + TVM_FFI_ICHECK_EQ(data->device.device_type, kDLCPU) + << "Scalars must reside on the CPU to be printed"; if (dtype == kInt16) { auto value = static_cast(data->data)[0]; os << value << "i16"; @@ -108,7 +111,8 @@ std::string TensorScalarToString(const runtime::Tensor& data) { auto value = static_cast(data->data)[0]; os << (value ? "True" : "False"); } else { - LOG(FATAL) << "Unrecognized Tensor scalar dtype: " << DLDataTypeToString(dtype); + TVM_FFI_THROW(InternalError) << "Unrecognized Tensor scalar dtype: " + << DLDataTypeToString(dtype); } return os.str(); } @@ -124,7 +128,8 @@ std::string IntImmToString(const IntImm& int_imm) { } else if (int_imm->dtype == kBool) { os << (int_imm->value ? "True" : "False"); } else { - LOG(FATAL) << "Unrecognised IntImm dtype: " << DLDataTypeToString(int_imm->dtype); + TVM_FFI_THROW(InternalError) << "Unrecognised IntImm dtype: " + << DLDataTypeToString(int_imm->dtype); } return os.str(); } @@ -138,7 +143,8 @@ std::string FloatImmToString(const FloatImm& float_imm) { } else if (float_imm->dtype == kFloat64) { os << float_imm->value << "f64"; } else { - LOG(FATAL) << "Unrecognised FloatImm dtype: " << DLDataTypeToString(float_imm->dtype); + TVM_FFI_THROW(InternalError) << "Unrecognised FloatImm dtype: " + << DLDataTypeToString(float_imm->dtype); } return os.str(); } @@ -159,7 +165,7 @@ IntImm ValueToIntImm(int64_t value, int width) { } else if (width == 64) { return IntImm(kInt64, value); } else { - LOG(FATAL) << "Unrecognized int scalar width: " << width; + TVM_FFI_THROW(InternalError) << "Unrecognized int scalar width: " << width; } } @@ -178,7 +184,7 @@ FloatImm ValueToFloatImm(double value, int width) { } else if (width == 64) { return FloatImm(kFloat64, value); } else { - LOG(FATAL) << "Unrecognized float scalar width: " << width; + TVM_FFI_THROW(InternalError) << "Unrecognized float scalar width: " << width; } } diff --git a/src/support/socket.h b/src/support/socket.h index ed8ee4721bcb..364705449765 100644 --- a/src/support/socket.h +++ b/src/support/socket.h @@ -78,7 +78,7 @@ namespace support { inline std::string GetHostName() { std::string buf; buf.resize(256); - ICHECK_NE(gethostname(&buf[0], 256), -1); + TVM_FFI_ICHECK_NE(gethostname(&buf[0], 256), -1); return std::string(buf.c_str()); } @@ -120,7 +120,7 @@ struct SockAddr { size_t sep = url.find(","); std::string host = url.substr(2, sep - 3); std::string port = url.substr(sep + 1, url.length() - 1); - ICHECK(ValidateIP(host)) << "Url address is not valid " << url; + TVM_FFI_ICHECK(ValidateIP(host)) << "Url address is not valid " << url; if (host == "localhost") { host = "127.0.0.1"; } @@ -140,7 +140,7 @@ struct SockAddr { hints.ai_socktype = SOCK_STREAM; addrinfo* res = nullptr; int sig = getaddrinfo(host, nullptr, &hints, &res); - ICHECK(sig == 0 && res != nullptr) << "cannot obtain address of " << host; + TVM_FFI_ICHECK(sig == 0 && res != nullptr) << "cannot obtain address of " << host; switch (res->ai_family) { case AF_INET: { sockaddr_in* addr4 = reinterpret_cast(&addr); @@ -155,7 +155,7 @@ struct SockAddr { addr6->sin6_family = AF_INET6; } break; default: - ICHECK(false) << "cannot decode address"; + TVM_FFI_ICHECK(false) << "cannot decode address"; } freeaddrinfo(res); } @@ -180,7 +180,7 @@ struct SockAddr { const in_addr& addr4 = reinterpret_cast(&addr)->sin_addr; sinx_addr = reinterpret_cast(&addr4); } else { - ICHECK(false) << "illegal address"; + TVM_FFI_ICHECK(false) << "illegal address"; } #ifdef _WIN32 @@ -190,7 +190,7 @@ struct SockAddr { const char* s = inet_ntop(addr.ss_family, sinx_addr, &buf[0], static_cast(buf.length())); #endif - ICHECK(s != nullptr) << "cannot decode address"; + TVM_FFI_ICHECK(s != nullptr) << "cannot decode address"; std::ostringstream os; os << s << ":" << port(); return os.str(); @@ -339,7 +339,7 @@ class Socket { } if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { WSACleanup(); - LOG(FATAL) << "Could not find a usable version of Winsock.dll"; + TVM_FFI_THROW(InternalError) << "Could not find a usable version of Winsock.dll"; } #endif } @@ -358,9 +358,9 @@ class Socket { static void Error(const char* msg) { int errsv = GetLastErrorCode(); #ifdef _WIN32 - LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv; + TVM_FFI_THROW(InternalError) << "Socket " << msg << " Error:WSAError-code=" << errsv; #else - LOG(FATAL) << "Socket " << msg << " Error:" << strerror(errsv); + TVM_FFI_THROW(InternalError) << "Socket " << msg << " Error:" << strerror(errsv); #endif } @@ -522,7 +522,7 @@ class TCPSocket : public Socket, public Stream { GetLastErrorCode); if (ret == -1) { if (LastErrorWouldBlock()) { - LOG(FATAL) << "would block"; + TVM_FFI_THROW(InternalError) << "would block"; } Socket::Error("RecvAll"); } @@ -538,8 +538,8 @@ class TCPSocket : public Socket, public Stream { */ void SendBytes(std::string data) { int datalen = data.length(); - ICHECK_EQ(SendAll(&datalen, sizeof(datalen)), sizeof(datalen)); - ICHECK_EQ(SendAll(data.c_str(), datalen), datalen); + TVM_FFI_ICHECK_EQ(SendAll(&datalen, sizeof(datalen)), sizeof(datalen)); + TVM_FFI_ICHECK_EQ(SendAll(data.c_str(), datalen), datalen); } /*! * \brief Receive the data to remote. @@ -547,10 +547,10 @@ class TCPSocket : public Socket, public Stream { */ std::string RecvBytes() { int datalen = 0; - ICHECK_EQ(RecvAll(&datalen, sizeof(datalen)), sizeof(datalen)); + TVM_FFI_ICHECK_EQ(RecvAll(&datalen, sizeof(datalen)), sizeof(datalen)); std::string data; data.resize(datalen); - ICHECK_EQ(RecvAll(&data[0], datalen), datalen); + TVM_FFI_ICHECK_EQ(RecvAll(&data[0], datalen), datalen); return data; } diff --git a/src/support/table_printer.h b/src/support/table_printer.h index 51c7c7007c15..078055aa95d9 100644 --- a/src/support/table_printer.h +++ b/src/support/table_printer.h @@ -128,7 +128,7 @@ inline std::string TablePrinter::AsStr() const { column_width[i] = std::max(column_width[i], row[i].size()); } } - ICHECK(!column_width.empty()); + TVM_FFI_ICHECK(!column_width.empty()); size_t total_width = std::accumulate(column_width.begin(), column_width.end(), 0) + 3 * column_width.size() - 1; bool is_first = true; diff --git a/src/target/build_common.h b/src/target/build_common.h index e69e992af0a2..b1192eeca8e0 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -43,7 +43,8 @@ inline ffi::Map ExtractFuncInfo(const IRModu ffi::Map fmap; for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; + TVM_FFI_ICHECK(kv.second->IsInstance()) + << "Can only lower IR Module with PrimFuncs"; auto f = Downcast(kv.second); ffi::Array arg_types; diff --git a/src/target/canonicalizer/llvm/arm_aprofile.cc b/src/target/canonicalizer/llvm/arm_aprofile.cc index 44aab4035622..97f0071394a9 100644 --- a/src/target/canonicalizer/llvm/arm_aprofile.cc +++ b/src/target/canonicalizer/llvm/arm_aprofile.cc @@ -86,7 +86,7 @@ bool CheckContains(ffi::Array array, ffi::String predicate) { static ffi::Map GetFeatures(ffi::Map target) { #ifdef TVM_LLVM_VERSION ffi::String kind = Downcast(target.Get("kind").value()); - ICHECK_EQ(kind, "llvm") << "Expected target kind 'llvm', but got '" << kind << "'"; + TVM_FFI_ICHECK_EQ(kind, "llvm") << "Expected target kind 'llvm', but got '" << kind << "'"; ffi::Optional mtriple = Downcast>(target.Get("mtriple").value_or(nullptr)); diff --git a/src/target/canonicalizer/llvm/canonicalize.cc b/src/target/canonicalizer/llvm/canonicalize.cc index 7ddd5420c0a5..732528021d74 100644 --- a/src/target/canonicalizer/llvm/canonicalize.cc +++ b/src/target/canonicalizer/llvm/canonicalize.cc @@ -31,8 +31,8 @@ namespace llvm { ffi::Optional DetectSystemTriple() { #ifdef TVM_LLVM_VERSION auto pf = tvm::ffi::Function::GetGlobal("target.llvm_get_system_triple"); - ICHECK(pf.has_value()) << "The target llvm_get_system_triple was not found, " - "please compile with USE_LLVM = ON"; + TVM_FFI_ICHECK(pf.has_value()) << "The target llvm_get_system_triple was not found, " + "please compile with USE_LLVM = ON"; return (*pf)().cast(); #endif return {}; diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 61a455f5305a..b1bb9ae3e2dc 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -54,7 +54,7 @@ ffi::Module Build(IRModule mod, Target target) { // the build function. std::string build_f_name = "target.build." + target->kind->name; const auto bf = tvm::ffi::Function::GetGlobal(build_f_name); - ICHECK(bf.has_value()) << build_f_name << " is not enabled"; + TVM_FFI_ICHECK(bf.has_value()) << build_f_name << " is not enabled"; return (*bf)(mod, target).cast(); } @@ -68,23 +68,23 @@ class ModuleSerializer { stream->Write(import_tree_row_ptr_); stream->Write(import_tree_child_indices_); for (const auto& group : mod_group_vec_) { - ICHECK_NE(group.size(), 0) << "Every allocated group must have at least one module"; + TVM_FFI_ICHECK_NE(group.size(), 0) << "Every allocated group must have at least one module"; // we prioritize export dso when a module is both serializable and exportable if (export_dso) { if (group[0]->GetPropertyMask() & ffi::Module::kCompilationExportable) { std::string mod_type_key = "_lib"; stream->Write(mod_type_key); } else if (group[0]->GetPropertyMask() & ffi::Module::kBinarySerializable) { - ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged"; + TVM_FFI_ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged"; std::string mod_type_key = group[0]->kind(); stream->Write(mod_type_key); std::string bytes = group[0]->SaveToBytes(); stream->Write(bytes); } } else { - ICHECK(group[0]->GetPropertyMask() & ffi::Module::kBinarySerializable) + TVM_FFI_ICHECK(group[0]->GetPropertyMask() & ffi::Module::kBinarySerializable) << group[0]->kind() << " is not binary serializable."; - ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged"; + TVM_FFI_ICHECK_EQ(group.size(), 1U) << "Non DSO module is never merged"; std::string mod_type_key = group[0]->kind(); stream->Write(mod_type_key); std::string bytes = group[0]->SaveToBytes(); @@ -192,8 +192,8 @@ class ModuleSerializer { // Check cycles due to merging dso exportable modules. if (child_indices.size() != 0) { // The index is supposed to follow the topological order. - CHECK_LT(parent_index, child_indices[0]) - << "RuntimeError: Cannot export due to multiple dso-exportables " + TVM_FFI_CHECK_LT(parent_index, child_indices[0], RuntimeError) + << "Cannot export due to multiple dso-exportables " << "that cannot be merged without creating a cycle in the import tree. " << "Related module keys: parent=" << mod_group_vec_[parent_index][0]->kind() << ", child=" << mod_group_vec_[child_indices[0]][0]->kind(); @@ -236,18 +236,19 @@ ffi::Module DeserializeModuleFromBytes(std::string blob) { uint64_t size = import_tree_row_ptr.size() - 1; for (uint64_t i = 0; i < size; ++i) { std::string tkey; - ICHECK(stream.Read(&tkey)); + TVM_FFI_ICHECK(stream.Read(&tkey)); // "_lib" serves as a placeholder in the module import tree to indicate where // to place the DSOModule - ICHECK(tkey != "_lib") << "Should not contain any placeholder for DSOModule."; + TVM_FFI_ICHECK(tkey != "_lib") << "Should not contain any placeholder for DSOModule."; if (tkey == "_import_tree") { - ICHECK(stream.Read(&import_tree_row_ptr)); - ICHECK(stream.Read(&import_tree_child_indices)); + TVM_FFI_ICHECK(stream.Read(&import_tree_row_ptr)); + TVM_FFI_ICHECK(stream.Read(&import_tree_child_indices)); } else { std::string bytes; - ICHECK(stream.Read(&bytes)); + TVM_FFI_ICHECK(stream.Read(&bytes)); auto loader = ffi::Function::GetGlobal("ffi.Module.load_from_bytes." + tkey); - ICHECK(loader.has_value()) << "ffi.Module.load_from_bytes." << tkey << " is not enabled"; + TVM_FFI_ICHECK(loader.has_value()) + << "ffi.Module.load_from_bytes." << tkey << " is not enabled"; auto m = (*loader)(ffi::Bytes(bytes)).cast(); modules.emplace_back(m); } @@ -256,12 +257,12 @@ ffi::Module DeserializeModuleFromBytes(std::string blob) { for (size_t i = 0; i < modules.size(); ++i) { for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) { auto child_index = import_tree_child_indices[j]; - ICHECK(child_index < modules.size()); + TVM_FFI_ICHECK(child_index < modules.size()); modules[i]->ImportModule(modules[child_index]); } } - ICHECK(!modules.empty()) << "modules cannot be empty when import tree is present"; + TVM_FFI_ICHECK(!modules.empty()) << "modules cannot be empty when import tree is present"; // invariance: root module is always at location 0. // The module order is collected via DFS ffi::Module root_mod = modules[0]; @@ -282,7 +283,7 @@ std::string PackImportsToBytes(const ffi::Module& mod) { std::string PackImportsToC(const ffi::Module& mod, bool system_lib, const std::string& c_symbol_prefix) { if (c_symbol_prefix.length() != 0) { - CHECK(system_lib) + TVM_FFI_ICHECK(system_lib) << "c_symbol_prefix advanced option should be used in conjuction with system-lib"; } @@ -327,7 +328,7 @@ ffi::Module PackImportsToLLVM(const ffi::Module& mod, bool system_lib, const std::string& llvm_target_string, const std::string& c_symbol_prefix) { if (c_symbol_prefix.length() != 0) { - CHECK(system_lib) + TVM_FFI_ICHECK(system_lib) << "c_symbol_prefix advanced option should be used in conjuction with system-lib"; } @@ -337,7 +338,7 @@ ffi::Module PackImportsToLLVM(const ffi::Module& mod, bool system_lib, std::string codegen_f_name = "codegen.codegen_blob"; // the codegen function. const auto codegen_f = tvm::ffi::Function::GetGlobal(codegen_f_name); - ICHECK(codegen_f.has_value()) << "codegen.codegen_blob is not presented."; + TVM_FFI_ICHECK(codegen_f.has_value()) << "codegen.codegen_blob is not presented."; return (*codegen_f)(ffi::Bytes(blob), system_lib, llvm_target_string, c_symbol_prefix) .cast(); } diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc index 9f534e8d69b4..9d6459df6cce 100644 --- a/src/target/datatype/registry.cc +++ b/src/target/datatype/registry.cc @@ -55,20 +55,20 @@ Registry* Registry::Global() { } void Registry::Register(const std::string& type_name, uint8_t type_code) { - ICHECK(type_code >= DataType::kCustomBegin) + TVM_FFI_ICHECK(type_code >= DataType::kCustomBegin) << "Please choose a type code >= DataType::kCustomBegin for custom types"; code_to_name_[type_code] = type_name; name_to_code_[type_name] = type_code; } uint8_t Registry::GetTypeCode(const std::string& type_name) { - ICHECK(name_to_code_.find(type_name) != name_to_code_.end()) + TVM_FFI_ICHECK(name_to_code_.find(type_name) != name_to_code_.end()) << "Type name " << type_name << " not registered"; return name_to_code_[type_name]; } std::string Registry::GetTypeName(uint8_t type_code) { - ICHECK(code_to_name_.find(type_code) != code_to_name_.end()) + TVM_FFI_ICHECK(code_to_name_.find(type_code) != code_to_name_.end()) << "Type code " << static_cast(type_code) << " not registered"; return code_to_name_[type_code]; } diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 259eacc53812..91701b067b47 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -125,12 +125,12 @@ TVM_REGISTER_OP("tir.pow").set_attr("default.FLowerIntrinsic", TVM_REGISTER_OP("tir.tvm_access_ptr") .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); - ICHECK(call != nullptr); - ICHECK_EQ(call->args.size(), 5U); + TVM_FFI_ICHECK(call != nullptr); + TVM_FFI_ICHECK_EQ(call->args.size(), 5U); DataType dtype = call->args[0].dtype(); Var buffer_var = Downcast(call->args[1]); PrimExpr offset = call->args[2]; - ICHECK(call->dtype.is_handle()); + TVM_FFI_ICHECK(call->dtype.is_handle()); if (dtype.lanes() != 1) { offset = offset * make_const(offset.dtype(), dtype.lanes()); offset = Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes()); @@ -144,15 +144,15 @@ TVM_REGISTER_OP("tir.tvm_access_ptr") PrimExpr DispatchFastErf(const PrimExpr& e) { DLOG(WARNING) << "fast_erf will be used instead of erf"; const CallNode* call = e.as(); - ICHECK(call != nullptr); - ICHECK_EQ(call->args.size(), 1); + TVM_FFI_ICHECK(call != nullptr); + TVM_FFI_ICHECK_EQ(call->args.size(), 1); PrimExpr arg = call->args[0]; int bits = arg.dtype().bits(); PrimExpr res; if (arg.dtype().is_float() && (bits == 16 || bits == 32)) { res = fast_erf_float_expr(arg, bits); } else { - LOG(FATAL) << "Unsupported type in Metal fast_erf"; + TVM_FFI_THROW(InternalError) << "Unsupported type in Metal fast_erf"; } return res; } @@ -161,7 +161,7 @@ PrimExpr DispatchNumericalStableTanh(const PrimExpr& e) { using tir::make_const; using tir::make_zero; const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; PrimExpr one = make_const(x.dtype(), 1); PrimExpr two = make_const(x.dtype(), 2); @@ -184,7 +184,7 @@ using namespace tir; TVM_REGISTER_OP("tir.rsqrt") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); auto one = make_const(call->args[0].dtype(), 1); return one / sqrt(call->args[0]); }); @@ -192,7 +192,7 @@ TVM_REGISTER_OP("tir.rsqrt") TVM_REGISTER_OP("tir.sigmoid") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); auto one = make_const(call->args[0].dtype(), 1); return one / (one + exp(-call->args[0])); }); @@ -200,14 +200,14 @@ TVM_REGISTER_OP("tir.sigmoid") TVM_REGISTER_OP("tir.isfinite") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); return isfinite(call->args[0]); }); TVM_REGISTER_OP("tir.isinf") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); return isinf(call->args[0]); }); @@ -223,9 +223,11 @@ TVM_REGISTER_OP("tir.isinf") static PrimExpr QMultiplyShift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr left_shift, PrimExpr right_shift, PrimExpr is_left_shift_required) { // Only int32 types are supported (any number of lanes is allowed) - ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32); - ICHECK(left_shift.dtype().code() == DLDataTypeCode::kDLInt && left_shift.dtype().bits() == 32); - ICHECK(right_shift.dtype().code() == DLDataTypeCode::kDLInt && right_shift.dtype().bits() == 32); + TVM_FFI_ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32); + TVM_FFI_ICHECK(left_shift.dtype().code() == DLDataTypeCode::kDLInt && + left_shift.dtype().bits() == 32); + TVM_FFI_ICHECK(right_shift.dtype().code() == DLDataTypeCode::kDLInt && + right_shift.dtype().bits() == 32); DataType hp_dtype = DataType::Int(64, x.dtype().lanes()); DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); @@ -256,7 +258,7 @@ TVM_REGISTER_OP("tir.q_multiply_shift") using tir::make_const; const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); PrimExpr x = call->args[0]; PrimExpr y = call->args[1]; @@ -269,9 +271,9 @@ TVM_REGISTER_OP("tir.q_multiply_shift") return int_node->value; } auto broadcast_node = node.as(); - CHECK(broadcast_node != nullptr); + TVM_FFI_ICHECK(broadcast_node != nullptr); auto int_node = broadcast_node->value.as(); - CHECK(int_node != nullptr); + TVM_FFI_ICHECK(int_node != nullptr); return int_node->value; }; // Power of 2 is determined by the fixed_point_multiplier == 1 << 30. In case of power of @@ -294,7 +296,7 @@ TVM_REGISTER_OP("tir.q_multiply_shift") } } else { // Only int32 types are supported (any number of lanes is allowed) - ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32); + TVM_FFI_ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32); // Calculating integer shifts PrimExpr zero = make_const(s.dtype(), 0); @@ -309,7 +311,7 @@ TVM_REGISTER_OP("tir.q_multiply_shift") TVM_REGISTER_OP("tir.q_multiply_shift_per_axis") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); PrimExpr x = call->args[0]; PrimExpr y = call->args[1]; diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index 5b6b0e107c02..3f5ac43211ea 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -62,16 +62,16 @@ struct Direct { template inline PrimExpr DispatchPureExtern(const PrimExpr& e) { const CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); // Use string based dispatch to extern for backward compact // TODO(tvm-team) replace once the new dispatching system is inplace. const OpNode* op = call->op.as(); - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); std::string name = op->name; - ICHECK_EQ(name.substr(0, 4), "tir."); + TVM_FFI_ICHECK_EQ(name.substr(0, 4), "tir."); DataType dtype; if (dtype_from_arg) { - ICHECK_EQ(call->args.size(), 1U); + TVM_FFI_ICHECK_EQ(call->args.size(), 1U); dtype = call->args[0].dtype(); } else { dtype = call->dtype; diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc index 872e4f4cd110..22e7d873b646 100644 --- a/src/target/llvm/codegen_aarch64.cc +++ b/src/target/llvm/codegen_aarch64.cc @@ -87,20 +87,20 @@ void CodeGenAArch64::VisitStmt_(const AttrStmtNode* op) { } const auto* attr_value = op->value.as(); - ICHECK(attr_value) << "Expect " << attr_key << " to have a ffi::String value but was " - << op->value->GetTypeKey(); + TVM_FFI_ICHECK(attr_value) << "Expect " << attr_key << " to have a ffi::String value but was " + << op->value->GetTypeKey(); std::string aarch64_attr_key = attr_key.substr(7); if (aarch64_attr_key == "aarch64_pstate_sm") { - ICHECK(!func_has_pstate_sm) << "Multiple definitions of " << op->attr_key - << " attribute found in the function " - << function_->getName().data(); + TVM_FFI_ICHECK(!func_has_pstate_sm) + << "Multiple definitions of " << op->attr_key << " attribute found in the function " + << function_->getName().data(); function_->addFnAttr({aarch64_attr_key + "_" + attr_value->value}); func_has_pstate_sm = true; } else if (aarch64_attr_key == "aarch64_pstate_za") { - ICHECK(!func_has_pstate_za) << "Multiple definitions of " << op->attr_key - << " attribute found in the function " - << function_->getName().data(); + TVM_FFI_ICHECK(!func_has_pstate_za) + << "Multiple definitions of " << op->attr_key << " attribute found in the function " + << function_->getName().data(); function_->addFnAttr({aarch64_attr_key + "_" + attr_value->value}); func_has_pstate_za = true; } else { diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 034b982f64b3..ac250982ce72 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -99,7 +99,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { } void VisitStmt_(const AllocateNode* op) final { - ICHECK(!is_zero(op->condition)); + TVM_FFI_ICHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); @@ -110,7 +110,8 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::GlobalValue::ExternalLinkage); } else { size_t constant_size = op->ConstantAllocationSize(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; + TVM_FFI_ICHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation in GPU"; if (constant_size % 4 == 0 && info.alignment == 0) { info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); @@ -139,7 +140,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { } buf = alloca; } else { - ICHECK(storage_scope.rank == runtime::StorageRank::kShared) + TVM_FFI_ICHECK(storage_scope.rank == runtime::StorageRank::kShared) << "Can only allocate shared or local memory inside kernel"; // Shared memory: address space == 3 buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment, @@ -150,7 +151,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { buf = builder_->CreatePointerCast( buf, llvmGetPointerTo(DTypeToLLVMType(op->dtype), buf->getType()->getPointerAddressSpace())); - ICHECK(!var_map_.count(op->buffer_var.get())); + TVM_FFI_ICHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); } @@ -171,10 +172,10 @@ class CodeGenAMDGPU : public CodeGenLLVM { intrin_id = llvm::Intrinsic::amdgcn_workitem_id_z; break; default: - LOG(FATAL) << "unknown workitem idx"; + TVM_FFI_THROW(InternalError) << "unknown workitem idx"; } } else { - ICHECK_EQ(ts.rank, 0); + TVM_FFI_ICHECK_EQ(ts.rank, 0); switch (ts.dim_index) { case 0: intrin_id = llvm::Intrinsic::amdgcn_workgroup_id_x; @@ -186,7 +187,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { intrin_id = llvm::Intrinsic::amdgcn_workgroup_id_z; break; default: - LOG(FATAL) << "unknown workgroup idx"; + TVM_FFI_THROW(InternalError) << "unknown workgroup idx"; } } #if TVM_LLVM_VERSION >= 200 @@ -213,7 +214,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { #endif return builder_->CreateCall(f, {}); } else { - LOG(FATAL) << "Do not support sync " << sync; + TVM_FFI_THROW(InternalError) << "Do not support sync " << sync; } } @@ -228,7 +229,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::Value* CreateIntrinsic(const CallNode* op) final { if (op->op.same_as(builtin::atomic_add())) { - ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now"; + TVM_FFI_ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now"; llvm::Value* v0 = MakeValue(op->args[0]); llvm::Value* v1 = MakeValue(op->args[1]); if (op->args[1]->dtype.is_float()) { @@ -241,7 +242,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::AtomicOrdering::Monotonic); #endif #else - LOG(FATAL) << "Floating point atomic requires LLVM 9 or newer"; + TVM_FFI_THROW(InternalError) << "Floating point atomic requires LLVM 9 or newer"; #endif } #if TVM_LLVM_VERSION >= 130 @@ -268,7 +269,7 @@ ffi::Module BuildAMDGPU(IRModule mod, Target target) { With llvm_target(llvm_instance, target); #if TVM_LLVM_VERSION < 90 - LOG(FATAL) << "AMDGPU backend requires at least LLVM 9"; + TVM_FFI_THROW(InternalError) << "AMDGPU backend requires at least LLVM 9"; // Lower versions will crash when loading the bitcode, see // issue #4087 for a discussion #endif @@ -314,16 +315,18 @@ ffi::Module BuildAMDGPU(IRModule mod, Target target) { llvm::legacy::PassManager pass; #if TVM_LLVM_VERSION <= 60 - ICHECK(tm->addPassesToEmitFile(pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0) + TVM_FFI_ICHECK(tm->addPassesToEmitFile(pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 90 - ICHECK(tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0) + TVM_FFI_ICHECK( + tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 170 - ICHECK(tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::CGFT_ObjectFile) == 0) + TVM_FFI_ICHECK(tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::CGFT_ObjectFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #else - ICHECK(tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::CodeGenFileType::ObjectFile) == 0) + TVM_FFI_ICHECK( + tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::CodeGenFileType::ObjectFile) == 0) << "Cannot emit target CodeGenFileType::ObjectFile"; #endif pass.run(*mObj); @@ -331,25 +334,26 @@ ffi::Module BuildAMDGPU(IRModule mod, Target target) { llvm::legacy::PassManager passAsm; #if TVM_LLVM_VERSION <= 60 - ICHECK(tm->addPassesToEmitFile(passAsm, destAsm, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + TVM_FFI_ICHECK( + tm->addPassesToEmitFile(passAsm, destAsm, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 90 - ICHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, - llvm::TargetMachine::CGFT_AssemblyFile) == 0) + TVM_FFI_ICHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, + llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 170 - ICHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, llvm::CGFT_AssemblyFile) == 0) + TVM_FFI_ICHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, llvm::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #else - ICHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, llvm::CodeGenFileType::AssemblyFile) == - 0) + TVM_FFI_ICHECK( + tm->addPassesToEmitFile(passAsm, destAsm, nullptr, llvm::CodeGenFileType::AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #endif passAsm.run(*mAsm); std::string assembly(dataAsm.begin(), dataAsm.end()); auto flink = tvm::ffi::Function::GetGlobal("tvm_callback_rocm_link"); - ICHECK(flink.has_value()) + TVM_FFI_ICHECK(flink.has_value()) << "Require tvm_callback_rocm_link to exist, do import tvm.contrib.rocm"; TVMFFIByteArray arr; diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index b1888a4928ab..c742b2f75fee 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -97,7 +97,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { PrimExpr input8 = reinterpret(uint8_type, e); // Popcount 8bit->8bit const CallNode* c0 = input8.as(); - ICHECK(c0 != nullptr); + TVM_FFI_ICHECK(c0 != nullptr); ffi::Array vcnt8_args; vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt8_args.push_back(input8); diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index bc67cdad2fd3..906d79b176ce 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -236,7 +236,7 @@ void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) { } // create a wrapper function with tvm_ffi_main name and redirects to the entry function llvm::Function* target_func = module_->getFunction(entry_func_name); - ICHECK(target_func) << "Function " << entry_func_name << " does not exist in module"; + TVM_FFI_ICHECK(target_func) << "Function " << entry_func_name << " does not exist in module"; // Create wrapper function llvm::Function* wrapper_func = @@ -286,7 +286,7 @@ CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value if (buf->getType() == t_void_p_) { buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_tvm_array_, 0)); } else { - ICHECK_EQ(buf->getType(), llvmGetPointerTo(t_tvm_array_, 0)); + TVM_FFI_ICHECK_EQ(buf->getType(), llvmGetPointerTo(t_tvm_array_, 0)); } } switch (kind) { @@ -364,7 +364,7 @@ CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value return TypedPointer(t_int32_, buf); } case builtin::kTVMFFIAnyUnionValue: { - ICHECK_EQ(t.lanes(), 1); + TVM_FFI_ICHECK_EQ(t.lanes(), 1); buf = builder_->CreatePointerCast(buf, llvmGetPointerTo(t_tvm_ffi_any_, 0)); // field 2 is the union value buf = builder_->CreateInBoundsGEP(t_tvm_ffi_any_, buf, {index, ConstInt32(2)}); @@ -386,7 +386,7 @@ CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value } } default: - LOG(FATAL) << "unknown field code"; + TVM_FFI_THROW(InternalError) << "unknown field code"; } } @@ -444,7 +444,7 @@ llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string } llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) { - ICHECK(gv != nullptr); + TVM_FFI_ICHECK(gv != nullptr); #if TVM_LLVM_VERSION >= 110 llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv->getValueType(), gv, llvm::Align(gv->getAlignment())); @@ -548,7 +548,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { // to call them correctly on MIPS platform (CALL16 issue) // Linkage ld Error: CALL16 reloc at 0x290 not against global symbol const StringImmNode* value = op->value.as(); - ICHECK(value != nullptr); + TVM_FFI_ICHECK(value != nullptr); llvm::Function* fcompute = llvm::Function::Create(ftype, llvm::Function::InternalLinkage, MakeStringRef(value->value), module_.get()); SetTargetAttributes(fcompute); @@ -611,7 +611,7 @@ CodeGenLLVM::TypedPointer CodeGenCPU::PackClosureData(const ffi::Array& vfi std::vector fields; for (Var v : vfields) { auto it = var_map_.find(v.get()); - ICHECK(it != var_map_.end()); + TVM_FFI_ICHECK(it != var_map_.end()); fields.push_back(it->second->getType()); } llvm::StructType* ctype = struct_name.size() ? llvm::StructType::create(fields, struct_name) @@ -692,7 +692,8 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task, std::strin std::swap(analyzer_, new_analyzer); std::swap(parallel_env_, par_env); std::swap(function_, f); - ICHECK_NE(par_env.parallel_loop_count, 0) << "Cannot find parallel loop within parallel launch"; + TVM_FFI_ICHECK_NE(par_env.parallel_loop_count, 0) + << "Cannot find parallel loop within parallel launch"; builder_->SetInsertPoint(par_launch_end); } @@ -735,7 +736,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod // setup new variable map, swap it with current var context. std::unordered_map new_vmap; UnpackClosureData(cdata, vfields, &new_vmap); - ICHECK(parallel_env_.penv == nullptr); + TVM_FFI_ICHECK(parallel_env_.penv == nullptr); auto new_analyzer = std::make_unique(); std::swap(function_, f); std::swap(analyzer_, new_analyzer); @@ -839,14 +840,14 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const ffi::Array(); - ICHECK(ptr) << "Expected first argument of tir::Call to be " - << "a string containing the callee's name, " - << "but instead contained " << args[0]; + TVM_FFI_ICHECK(ptr) << "Expected first argument of tir::Call to be " + << "a string containing the callee's name, " + << "but instead contained " << args[0]; return ptr->value; }(); // call the function int64_t nargs = end - begin; - ICHECK_GE(nargs, 0); + TVM_FFI_ICHECK_GE(nargs, 0); llvm::Value* stack_args = MakeValue(args[1]); llvm::Value* packed_args = builder_->CreateInBoundsGEP( t_tvm_ffi_any_, builder_->CreatePointerCast(stack_args, llvmGetPointerTo(t_tvm_ffi_any_, 0)), @@ -922,7 +923,7 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const ffi::Arrayargs.size(), 4U); + TVM_FFI_ICHECK_EQ(op->args.size(), 4U); bool use_string_lookup = op->op.same_as(builtin::tvm_call_packed_lowered()); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[2].as()->value, op->args[3].as()->value, use_string_lookup); @@ -930,7 +931,7 @@ llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op) { } llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { - ICHECK_EQ(op->args.size(), 5U); + TVM_FFI_ICHECK_EQ(op->args.size(), 5U); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[2].as()->value, op->args[3].as()->value, true); llvm::LLVMContext* ctx = llvm_target_->GetContext(); @@ -1015,7 +1016,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { builder_->SetInsertPoint(new_bb); return ConstInt32(-1); } else if (op->op.same_as(builtin::tvm_struct_get())) { - ICHECK_EQ(op->args.size(), 3U); + TVM_FFI_ICHECK_EQ(op->args.size(), 3U); int kind = op->args[2].as().value()->value; TypedPointer ref = CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); @@ -1031,12 +1032,12 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { return struct_value; } else if (op->op.same_as(builtin::tvm_struct_set())) { - ICHECK_EQ(op->args.size(), 4U); + TVM_FFI_ICHECK_EQ(op->args.size(), 4U); int kind = op->args[2].as().value()->value; llvm::Value* value = MakeValue(op->args[3]); TypedPointer ref = CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), MakeValue(op->args[1]), kind); - ICHECK(kind != builtin::kArrAddr); + TVM_FFI_ICHECK(kind != builtin::kArrAddr); if (value->getType()->isPointerTy()) { value = builder_->CreatePointerCast(value, ref.type); } @@ -1053,11 +1054,11 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { builder_->CreateStore(value, ref.addr); return ConstInt32(0); } else if (op->op.same_as(builtin::tvm_stack_alloca())) { - ICHECK_EQ(op->args.size(), 2U); + TVM_FFI_ICHECK_EQ(op->args.size(), 2U); std::string type = op->args[0].as().value()->value; return WithFunctionEntry([&]() -> llvm::AllocaInst* { const int64_t* pval = as_const_int(op->args[1]); - ICHECK(pval) << "require stack alloca to contain constant value"; + TVM_FFI_ICHECK(pval) << "require stack alloca to contain constant value"; llvm::Value* num = ConstInt32(pval[0]); if (type == "shape") { return builder_->CreateAlloca(t_tvm_shape_index_, num); @@ -1070,7 +1071,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { alloca->setAlignment(llvm::Align(64)); return alloca; } else { - LOG(FATAL) << "Unknown stack alloca type " << type; + TVM_FFI_THROW(InternalError) << "Unknown stack alloca type " << type; } }); } else { @@ -1111,21 +1112,22 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { EmitDebugLocation(op); if (op->attr_key == tir::attr::coproc_uop_scope) { const StringImmNode* value = op->value.as(); - ICHECK(value != nullptr); + TVM_FFI_ICHECK(value != nullptr); this->CreateStaticInit(value->value, op->body); } else if (op->attr_key == tir::attr::compute_scope) { this->CreateComputeScope(op); } else if (tir::attr::IsPragmaKey(op->attr_key)) { if (op->attr_key == "pragma_parallel_stride_pattern") { - ICHECK(parallel_env_.penv != nullptr) + TVM_FFI_ICHECK(parallel_env_.penv != nullptr) << "Pragma parallel_stride_pattern only valid in parallel launch"; parallel_env_.stride_pattern = true; this->VisitStmt(op->body); } else if (op->attr_key == "pragma_parallel_launch_point") { CreateParallelLaunch(op->body, 0, "pragma_parallel"); } else if (op->attr_key == "pragma_parallel_barrier_when_finish") { - ICHECK(parallel_env_.penv != nullptr) << "Cannot run barrier without parallel environment"; - ICHECK(!parallel_env_.in_parallel_loop) + TVM_FFI_ICHECK(parallel_env_.penv != nullptr) + << "Cannot run barrier without parallel environment"; + TVM_FFI_ICHECK(!parallel_env_.in_parallel_loop) << "Cannot not place within parallel loop as the workload may differ, " << " place it between parallel and parallel_launch_point"; this->VisitStmt(op->body); @@ -1138,7 +1140,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { builder_->CreateCall(bar_callee, {MakeValue(parallel_env_.task_id), parallel_env_.penv}); } else if (op->attr_key == tir::attr::pragma_import_llvm) { const StringImmNode* value = op->value.as(); - ICHECK(value != nullptr); + TVM_FFI_ICHECK(value != nullptr); this->HandleImport(value->value); this->VisitStmt(op->body); } else { @@ -1155,21 +1157,23 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { if (op->kind == ForKind::kSerial || op->kind == ForKind::kUnrolled) { CodeGenLLVM::VisitStmt_(op); } else if (op->kind == ForKind::kParallel) { - ICHECK(is_zero(op->min)) << "Parallel launch require canonical loop with zero start index"; - ICHECK(op->HasTrivialStep()) << "Parallel launch require canonical loop with trivial loop step"; + TVM_FFI_ICHECK(is_zero(op->min)) + << "Parallel launch require canonical loop with zero start index"; + TVM_FFI_ICHECK(op->HasTrivialStep()) + << "Parallel launch require canonical loop with trivial loop step"; if (parallel_env_.penv == nullptr) { auto copy_node = For(ffi::make_object(*op)); CreateParallelLaunch(copy_node, 0, std::string("loop_parallel_") + op->loop_var->name_hint.c_str()); } else { // already in parallel env. - ICHECK(parallel_env_.task_id.defined()); - ICHECK(parallel_env_.num_task.defined()); - ICHECK(parallel_env_.penv != nullptr); + TVM_FFI_ICHECK(parallel_env_.task_id.defined()); + TVM_FFI_ICHECK(parallel_env_.num_task.defined()); + TVM_FFI_ICHECK(parallel_env_.penv != nullptr); DataType t = op->extent.dtype(); PrimExpr num_task = cast(t, parallel_env_.num_task); PrimExpr task_id = cast(t, parallel_env_.task_id); - ICHECK(!parallel_env_.in_parallel_loop) + TVM_FFI_ICHECK(!parallel_env_.in_parallel_loop) << "Nested parallel loop is not supported by threadpool, try fuse them instead"; parallel_env_.in_parallel_loop = true; PrimExpr end = is_zero(op->min) ? op->extent : analyzer_->Simplify(op->min + op->extent); @@ -1187,7 +1191,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { ++parallel_env_.parallel_loop_count; } } else { - LOG(FATAL) << "cannot handle for type " << op->kind; + TVM_FFI_THROW(InternalError) << "cannot handle for type " << op->kind; } } diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index a546ad2019f2..61922a6342de 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -129,18 +129,18 @@ void CodeGenHexagon::InitTarget() { #if TVM_LLVM_VERSION >= 180 if (!fs.starts_with(hvx_length_feature)) continue; - ICHECK(fs.ends_with("b")) << "malformed target feature: " << f; + TVM_FFI_ICHECK(fs.ends_with("b")) << "malformed target feature: " << f; #else if (!fs.startswith(hvx_length_feature)) continue; - ICHECK(fs.endswith("b")) << "malformed target feature: " << f; + TVM_FFI_ICHECK(fs.endswith("b")) << "malformed target feature: " << f; #endif int hvx_bytes = 0; size_t len_begin = std::strlen(hvx_length_feature); - ICHECK(!fs.substr(len_begin, fs.size() - len_begin - 1).getAsInteger(10, hvx_bytes)) + TVM_FFI_ICHECK(!fs.substr(len_begin, fs.size() - len_begin - 1).getAsInteger(10, hvx_bytes)) << "invalid HVX length in feature string: " << f; - ICHECK(hvx_bytes == 64 || hvx_bytes == 128) + TVM_FFI_ICHECK(hvx_bytes == 64 || hvx_bytes == 128) << "invalid HVX vector length: " << hvx_bytes << ", should be 64 or 128"; native_vector_bits_ = hvx_bytes * 8; // There should only be one hvx-length... @@ -278,8 +278,9 @@ CodeGenLLVM::TypedPointer CodeGenHexagon::CreateBufferPtr(llvm::Value* buffer_pt return CodeGenCPU::CreateBufferPtr(buffer_ptr, buffer_element_dtype, indices, value_dtype); } - ICHECK_EQ(indices.size(), 2) << "CodegenHexagon supports 1-d and 2-d physical buffers, received " - << indices.size() << "-d buffer indices"; + TVM_FFI_ICHECK_EQ(indices.size(), 2) + << "CodegenHexagon supports 1-d and 2-d physical buffers, received " << indices.size() + << "-d buffer indices"; // Use the first index to identify the pointer. DataType dtype_void_ptr = DataType::Handle(); @@ -309,7 +310,7 @@ llvm::Value* CodeGenHexagon::Intrinsic(llvm::Intrinsic::ID IntID, #endif std::vector conv_args; llvm::FunctionType* intf_type = intf->getFunctionType(); - ICHECK(args.size() == intf_type->getNumParams()); + TVM_FFI_ICHECK(args.size() == intf_type->getNumParams()); for (int i = 0, e = args.size(); i != e; ++i) { llvm::Value* arg = args[i]; @@ -397,7 +398,7 @@ llvm::Value* CodeGenHexagon::VectorLookupLoad(Buffer buffer, DataType buffer_typ int res_bits = GetTypeSizeInBits(res_type); int ret_bits = GetTypeSizeInBits(ret_type); - ICHECK_GE(res_bits, ret_bits); + TVM_FFI_ICHECK_GE(res_bits, ret_bits); if (ret_bits < res_bits) { #if TVM_LLVM_VERSION >= 110 llvm::Type* res_byte_type = llvm::VectorType::get(t_int8_, res_bits / 8, /*Scalable*/ false); @@ -497,7 +498,7 @@ ffi::Module BuildHexagon(IRModule mod, Target target) { auto f = Downcast(kv.second); if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.has_value()); + TVM_FFI_ICHECK(global_symbol.has_value()); entry_func = global_symbol.value(); } } @@ -539,7 +540,8 @@ ffi::Module BuildHexagon(IRModule mod, Target target) { std::unique_ptr cm = llvm::CloneModule(m); llvm::legacy::PassManager pass; llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); - ICHECK(tm->addPassesToEmitFile(pass, os, nullptr, ft) == 0) << "Cannot emit target code"; + TVM_FFI_ICHECK(tm->addPassesToEmitFile(pass, os, nullptr, ft) == 0) + << "Cannot emit target code"; pass.run(*cm.get()); out.assign(ss.c_str(), ss.size()); } @@ -551,10 +553,10 @@ ffi::Module BuildHexagon(IRModule mod, Target target) { llvm::SmallString<64> file_name; int fd; std::error_code ec = llvm::sys::fs::createTemporaryFile("tvm", suffix, fd, file_name); - ICHECK_EQ(static_cast(ec), false) << ec.message(); + TVM_FFI_ICHECK_EQ(static_cast(ec), false) << ec.message(); llvm::raw_fd_ostream file(fd, true); file << data; - ICHECK(!file.has_error()) << file.error().message(); + TVM_FFI_ICHECK(!file.has_error()) << file.error().message(); // If there is an error, execution will never get here, but return // {ec, name} anyway to allow caller to handle error conditions. // This way the "ICHECK" above can be removed with minimal effort. @@ -571,23 +573,23 @@ ffi::Module BuildHexagon(IRModule mod, Target target) { so_name += "so"; const auto f = tvm::ffi::Function::GetGlobal("tvm.contrib.hexagon.link_shared"); - ICHECK(f.has_value()) << "tvm.contrib.hexagon.link_shared does not to exist, " - "do import tvm.contrib.hexagon"; + TVM_FFI_ICHECK(f.has_value()) << "tvm.contrib.hexagon.link_shared does not to exist, " + "do import tvm.contrib.hexagon"; ffi::Array o_names = {StringImm(o_name)}; ffi::Map extra_args; if (target->attrs.count("mcpu")) { std::string mcpu = Downcast(target->attrs.at("mcpu")); #if TVM_LLVM_VERSION >= 180 - ICHECK(llvm::StringRef(mcpu).starts_with("hexagon")) + TVM_FFI_ICHECK(llvm::StringRef(mcpu).starts_with("hexagon")) #else - ICHECK(llvm::StringRef(mcpu).startswith("hexagon")) + TVM_FFI_ICHECK(llvm::StringRef(mcpu).startswith("hexagon")) #endif << "unexpected -mcpu value in target:" << mcpu; extra_args.Set("hex_arch", llvm::StringRef(mcpu).drop_front(strlen("hexagon")).str()); } int rc = (*f)(so_name, o_names, extra_args).cast(); - ICHECK(rc == 0) << "Failed to link " << so_name; + TVM_FFI_ICHECK(rc == 0) << "Failed to link " << so_name; return HexagonModuleCreate(so_name, "so", ExtractFuncInfo(mod), asm_str, obj_str, ir_str, bc_str); } diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index b7004dec32e2..ccb9173cd1b5 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -128,12 +128,12 @@ std::unique_ptr CodeGenLLVM::Create(LLVMTarget* llvm_target) { } else if (auto f = tvm::ffi::Function::GetGlobal(factory_template + "cpu")) { handle = (*f)().cast(); } else { - LOG(FATAL) << "no factory function for codegen for target " << target; + TVM_FFI_THROW(InternalError) << "no factory function for codegen for target " << target; } if (handle) { return std::unique_ptr(static_cast(handle)); } else { - LOG(FATAL) << "unable to create codegen for target " << target; + TVM_FFI_THROW(InternalError) << "unable to create codegen for target " << target; } } @@ -212,7 +212,7 @@ void CodeGenLLVM::InitTarget() { os << "}\n"; auto mod = llvm_target_->GetInstance().ParseIR(os.str()); auto* test_sse2 = mod->getFunction(fname); - ICHECK(test_sse2 != nullptr) << "Module creation error"; + TVM_FFI_ICHECK(test_sse2 != nullptr) << "Module creation error"; use_float16_abi = tm->getSubtargetImpl(*test_sse2)->checkFeatures("+sse2"); } #endif // TVM_LLVM_VERSION >= 150 @@ -260,7 +260,7 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, cons return it->second; } - ICHECK_EQ(func->buffer_map.size(), 0U) + TVM_FFI_ICHECK_EQ(func->buffer_map.size(), 0U) << "Cannot codegen function with buffer_map, please lower them first"; std::vector param_types; @@ -348,15 +348,16 @@ void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f) void CodeGenLLVM::Verify() const { std::string verify_errors_storage; llvm::raw_string_ostream verify_errors(verify_errors_storage); - LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors)) - << "LLVM module verification failed with the following errors: \n" - << verify_errors.str(); + if (llvm::verifyModule(*module_, &verify_errors)) { + TVM_FFI_THROW(InternalError) << "LLVM module verification failed with the following errors: \n" + << verify_errors.str(); + } } std::unique_ptr CodeGenLLVM::Finish() { this->AddStartupFunction(); for (size_t i = 0; i < link_modules_.size(); ++i) { - ICHECK(!llvm::Linker::linkModules(*module_, std::move(link_modules_[i]))) + TVM_FFI_ICHECK(!llvm::Linker::linkModules(*module_, std::move(link_modules_[i]))) << "Failed to link modules"; } link_modules_.clear(); @@ -401,12 +402,16 @@ void CodeGenLLVM::AddLinkModule(std::unique_ptr&& mod) { } void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) { - LOG(FATAL) << "not implemented"; + TVM_FFI_THROW(InternalError) << "not implemented"; } -llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) { LOG(FATAL) << "not implemented"; } +llvm::Value* CodeGenLLVM::GetThreadIndex(const IterVar& iv) { + TVM_FFI_THROW(InternalError) << "not implemented"; +} -llvm::Value* CodeGenLLVM::CreateStorageSync(const CallNode* op) { LOG(FATAL) << "not implemented"; } +llvm::Value* CodeGenLLVM::CreateStorageSync(const CallNode* op) { + TVM_FFI_THROW(InternalError) << "not implemented"; +} #if TVM_LLVM_VERSION >= 160 @@ -479,8 +484,8 @@ void CodeGenLLVM::Optimize() { mpass.addPass(llvm::VerifierPass()); } if (auto err = builder.parsePassPipeline(mpass, pipeline)) { - LOG(FATAL) << "error parsing pass pipeline '" << pipeline - << "':" << llvm::toString(std::move(err)) << '\n'; + TVM_FFI_THROW(InternalError) << "error parsing pass pipeline '" << pipeline + << "':" << llvm::toString(std::move(err)) << '\n'; } mpass.run(*module_, mam); @@ -567,7 +572,7 @@ unsigned CodeGenLLVM::GetGlobalAddressSpace() const { return 0; } llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { if (dtype.is_handle()) { - ICHECK_EQ(dtype.lanes(), 1); + TVM_FFI_ICHECK_EQ(dtype.lanes(), 1); return t_void_p_; } if (dtype.is_void()) { @@ -591,7 +596,7 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { etype = llvm::Type::getDoubleTy(*ctx); break; default: - LOG(FATAL) << "do not support " << dtype; + TVM_FFI_THROW(InternalError) << "do not support " << dtype; } } else if (dtype.code() == DataType::kFloat8_e3m4 || dtype.code() == DataType::kFloat8_e4m3 || dtype.code() == DataType::kFloat8_e4m3b11fnuz || @@ -613,7 +618,7 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { return llvm::FixedVectorType::get(etype, dtype.lanes()); } #else - ICHECK(!dtype.is_scalable_vector()) + TVM_FFI_ICHECK(!dtype.is_scalable_vector()) << "Versions of LLVM < 11 do not support scalable vectors. Please upgrade to a later " "version."; return llvm::VectorType::get(etype, dtype.lanes()); @@ -644,7 +649,7 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const { } else if (type->IsInstance()) { return t_tvm_tensormap_; } else { - LOG(FATAL) << "Type " << type << " does not have a corresponding LLVM Type"; + TVM_FFI_THROW(InternalError) << "Type " << type << " does not have a corresponding LLVM Type"; } } @@ -787,7 +792,7 @@ void CodeGenLLVM::PopLoopFrame() { loop_frame_jump_tgts_.pop_back(); } llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) { int num_elems = GetVectorNumElements(vec); if (extent == num_elems && begin == 0) return vec; - ICHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n"; + TVM_FFI_ICHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n"; std::vector indices; indices.reserve(extent); for (int i = 0; i < extent; ++i) { @@ -817,7 +822,7 @@ llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) { llvm::Value* mask = llvm::UndefValue::get(DTypeToLLVMType(DataType::Int(32, target_lanes))); int num_elems = GetVectorNumElements(vec); if (num_elems == target_lanes) return vec; - ICHECK_LT(num_elems, target_lanes); + TVM_FFI_ICHECK_LT(num_elems, target_lanes); for (int i = 0; i < num_elems; ++i) { mask = builder_->CreateInsertElement(mask, ConstInt32(i), ConstInt32(i)); } @@ -894,7 +899,7 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2); AddDebugInformation(loop_value, loop_var); loop_value->addIncoming(begin, pre_block); - ICHECK(!var_map_.count(loop_var.get())); + TVM_FFI_ICHECK(!var_map_.count(loop_var.get())); var_map_[loop_var.get()] = loop_value; auto lt = CreateLT(loop_var.dtype(), loop_value, end); @@ -920,8 +925,8 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va llvm::Type* target = DTypeToLLVMType(to); if (value->getType() == target) return value; // TODO(tvm-team): consider add native support - ICHECK(!from.is_bfloat16()) << "BF16 needs to be storaged lowered first"; - ICHECK(!to.is_bfloat16()) << "BF16 needs to be storaged lowered first"; + TVM_FFI_ICHECK(!from.is_bfloat16()) << "BF16 needs to be storaged lowered first"; + TVM_FFI_ICHECK(!to.is_bfloat16()) << "BF16 needs to be storaged lowered first"; if (to.is_handle()) { return builder_->CreateBitCast(value, target); @@ -949,7 +954,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va } else if ((from.is_uint() || from.is_bool()) && to.is_float()) { return builder_->CreateUIToFP(value, target); } else { - ICHECK(from.is_float() && to.is_float()); + TVM_FFI_ICHECK(from.is_float() && to.is_float()); return builder_->CreateFPCast(value, target); } } @@ -986,11 +991,12 @@ CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype, llvm::ArrayRef indices, DataType value_dtype) { - ICHECK_EQ(indices.size(), 1) << "CodeGenLLVM requires all buffers to be flat 1-d buffers."; + TVM_FFI_ICHECK_EQ(indices.size(), 1) + << "CodeGenLLVM requires all buffers to be flat 1-d buffers."; llvm::Value* index = indices[0]; llvm::PointerType* buffer_ptr_type = llvm::dyn_cast(buffer_ptr->getType()); - ICHECK(buffer_ptr_type != nullptr); + TVM_FFI_ICHECK(buffer_ptr_type != nullptr); auto address_space = buffer_ptr_type->getAddressSpace(); llvm::Type* element_type = DTypeToLLVMType(buffer_element_dtype); @@ -999,12 +1005,12 @@ CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(llvm::Value* buffer_ptr, llvm::Type* value_type = DTypeToLLVMType(value_dtype); llvm::PointerType* value_ptr_type = llvmGetPointerTo(value_type, address_space); - ICHECK(index->getType()->isIntegerTy()) << "Expected buffer index to be an integer"; + TVM_FFI_ICHECK(index->getType()->isIntegerTy()) << "Expected buffer index to be an integer"; if (buffer_ptr_type != element_ptr_type) { buffer_ptr = builder_->CreatePointerCast(buffer_ptr, element_ptr_type); } - ICHECK(!HasAlignmentPadding(buffer_element_dtype)) + TVM_FFI_ICHECK(!HasAlignmentPadding(buffer_element_dtype)) << "DType " << buffer_element_dtype << " has padding for alignment. TVM data arrays are expected to be densely packed, with no " "padding for alignment."; @@ -1019,7 +1025,7 @@ CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(llvm::Value* buffer_ptr, llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const { auto it = var_map_.find(v); - ICHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint; + TVM_FFI_ICHECK(it != var_map_.end()) << "cannot find variable " << v->name_hint; return it->second; } @@ -1374,7 +1380,7 @@ void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) { llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) { - ICHECK_GE(op->args.size(), 1U); + TVM_FFI_ICHECK_GE(op->args.size(), 1U); llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); std::vector arg_value; std::vector arg_type; @@ -1384,8 +1390,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } llvm::Type* return_type = GetLLVMType(ffi::GetRef(op)); llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); - ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " - << llvmGetIntrinName(id); + TVM_FFI_ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " + << llvmGetIntrinName(id); // In earlier versions of LLVM's, the prefetch intrinsic is not // overloaded, and always takes the first argument as i8*. If // this is the case, this argument should insert a cast to i8*. @@ -1419,7 +1425,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return CreateStorageSync(op); } else if (op->op.same_as(builtin::address_of())) { const BufferLoadNode* load = op->args[0].as(); - ICHECK(op->args.size() == 1 && load); + TVM_FFI_ICHECK(op->args.size() == 1 && load); ffi::Array indices = load->indices; if (const RampNode* r = indices[indices.size() - 1].as()) { @@ -1443,13 +1449,14 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { llvm::Value* offset = MakeValue(op->args[1]); return builder_->CreateInBoundsGEP(t_int8_, ptr, offset); } else if (op->op.same_as(builtin::large_uint_imm())) { - ICHECK_EQ(op->args.size(), 2U); + TVM_FFI_ICHECK_EQ(op->args.size(), 2U); uint64_t low = static_cast(Downcast(op->args[0])->value); uint64_t high = static_cast(Downcast(op->args[1])->value); uint64_t val = (high << 32U) | low; return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val); } else if (op->op.same_as(builtin::if_then_else())) { - ICHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition"; + TVM_FFI_ICHECK_EQ(op->args[0].dtype().lanes(), 1) + << "if_then_else can only take scalar condition"; llvm::LLVMContext* ctx = llvm_target_->GetContext(); auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then", function_); auto* else_block = llvm::BasicBlock::Create(*ctx, "if_else", function_); @@ -1470,10 +1477,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return value; } else if (op->op.same_as(builtin::ret())) { auto const* val = op->args[0].as(); - ICHECK(val) << "the tir.ret should be transformed to return zero " - << "before the llvm code generation."; - ICHECK_EQ(val->value, 0) << "the tir.ret should be transformed to " - << "return zero before the llvm code generation."; + TVM_FFI_ICHECK(val) << "the tir.ret should be transformed to return zero " + << "before the llvm code generation."; + TVM_FFI_ICHECK_EQ(val->value, 0) << "the tir.ret should be transformed to " + << "return zero before the llvm code generation."; builder_->CreateRet(ConstInt32(0)); // LLVM allows exactly one terminator in a single basic block // append a new dummy basic block to avoid error. @@ -1482,7 +1489,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { builder_->SetInsertPoint(ret_dummy); return ret_dummy; } else if (op->op.same_as(builtin::continue_loop())) { - ICHECK(!loop_frame_jump_tgts_.empty()) + TVM_FFI_ICHECK(!loop_frame_jump_tgts_.empty()) << "the tir.continue_loop should be inserted under at least one For or While stmts."; builder_->CreateBr(loop_frame_jump_tgts_.back().first); // LLVM allows exactly one terminator in a single basic block @@ -1492,7 +1499,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { builder_->SetInsertPoint(post_dummy); return post_dummy; } else if (op->op.same_as(builtin::break_loop())) { - ICHECK(!loop_frame_jump_tgts_.empty()) + TVM_FFI_ICHECK(!loop_frame_jump_tgts_.empty()) << "the tir.break_loop should be inserted under at least one For or While stmts."; builder_->CreateBr(loop_frame_jump_tgts_.back().second); // LLVM allows exactly one terminator in a single basic block @@ -1531,7 +1538,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return builder_->CreateShuffleVector(v0, v1, indices); } else if (op->op.same_as(builtin::atomic_add())) { // TODO(masahi): Support atomic for CPU backend - LOG(FATAL) << "CPU backend does not support atomic add yet."; + TVM_FFI_THROW(InternalError) << "CPU backend does not support atomic add yet."; } else if (op->op.same_as(builtin::start_profile_intrinsic()) || op->op.same_as(builtin::end_profile_intrinsic())) { LOG(INFO) << "Ignoring profile_intrinsic ... " << op->op; @@ -1553,7 +1560,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return builder_->CreateCall(f, {MakeValue(op->args[0]), MakeValue(op->args[1])}); #endif } else { - LOG(FATAL) << "unknown intrinsic " << op->op; + TVM_FFI_THROW(InternalError) << "unknown intrinsic " << op->op; } } @@ -1602,7 +1609,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { return GetConstS return builder_->Create##Op(a, b); \ } \ } else { \ - ICHECK(t.is_float()); \ + TVM_FFI_ICHECK(t.is_float()); \ return builder_->CreateF##Op(a, b); \ } \ } \ @@ -1621,7 +1628,7 @@ DEFINE_CODEGEN_BINARY_OP(Mul); } else if (t.is_uint()) { \ return builder_->CreateICmpU##Op(a, b); \ } else { \ - ICHECK(t.is_float()); \ + TVM_FFI_ICHECK(t.is_float()); \ return builder_->CreateFCmpO##Op(a, b); \ } \ } \ @@ -1642,7 +1649,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) { } else if (op->dtype.is_uint()) { return builder_->CreateUDiv(a, b); } else { - ICHECK(op->dtype.is_float()); + TVM_FFI_ICHECK(op->dtype.is_float()); return builder_->CreateFDiv(a, b); } } @@ -1655,7 +1662,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) { } else if (op->dtype.is_uint()) { return builder_->CreateURem(a, b); } else { - ICHECK(op->dtype.is_float()); + TVM_FFI_ICHECK(op->dtype.is_float()); return builder_->CreateFRem(a, b); } } @@ -1712,7 +1719,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) { auto it = let_binding_.find(op->var); if (it != let_binding_.end()) { - ICHECK(deep_equal_(it->second->value, op->value)) + TVM_FFI_ICHECK(deep_equal_(it->second->value, op->value)) << "Let cannot bind the same var to two different values"; } else { let_binding_[op->var] = op; @@ -1739,7 +1746,7 @@ void CodeGenLLVM::BufferAccessHelper( make_instruction) { DataType buffer_element_dtype = buffer->dtype; - ICHECK_GE(indices.size(), 1) + TVM_FFI_ICHECK_GE(indices.size(), 1) << "Buffer " << buffer->name << " is accessed with no indices. " << "0-d scalar buffers are expected to be flattened to 1-d buffers prior to codegen."; @@ -1749,15 +1756,15 @@ void CodeGenLLVM::BufferAccessHelper( // requires 1-d indices. std::vector earlier_index_values; for (size_t i = 0; i < indices.size() - 1; i++) { - ICHECK_EQ(indices[i].dtype().lanes(), 1) + TVM_FFI_ICHECK_EQ(indices[i].dtype().lanes(), 1) << "Buffer " << buffer->name << " is accessed with a multi-lane index at position " << i << ". Multi-lane indices are only supported as the last index."; earlier_index_values.push_back(MakeValue(indices[i])); } PrimExpr last_index = indices[indices.size() - 1]; - ICHECK_EQ(value_dtype.get_lanes_or_vscale_factor(), - last_index.dtype().get_lanes_or_vscale_factor() * buffer_element_dtype.lanes()); + TVM_FFI_ICHECK_EQ(value_dtype.get_lanes_or_vscale_factor(), + last_index.dtype().get_lanes_or_vscale_factor() * buffer_element_dtype.lanes()); // Record index and elemtype in original form used for alias info PrimExpr last_index_origin = last_index; @@ -1791,7 +1798,7 @@ void CodeGenLLVM::BufferAccessHelper( } else { // Otherwise, alignment is based on the return value's scalar // type. - ICHECK_GE(value_dtype.bits(), 8); + TVM_FFI_ICHECK_GE(value_dtype.bits(), 8); alignment = value_dtype.bits() / 8; } @@ -1843,7 +1850,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { llvm::Value* predicate, int alignment, bool is_volatile) { llvm::Instruction* load = nullptr; if (predicate != nullptr) { - ICHECK(!is_volatile) + TVM_FFI_ICHECK(!is_volatile) << "The masked load intrinsic does not support declaring load as volatile."; #if TVM_LLVM_VERSION >= 130 load = builder_->CreateMaskedLoad(buffer_ptr.type, buffer_ptr.addr, llvm::Align(alignment), @@ -1889,7 +1896,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { auto call_op = opt_call_op.value(); if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { // call extern intrinsic - ICHECK_GE(op->args.size(), 1U); + TVM_FFI_ICHECK_GE(op->args.size(), 1U); auto global_symbol = Downcast(op->args[0]); return this->CreateCallExtern(GetType(ffi::GetRef(op)), global_symbol->value, op->args, true); @@ -1906,7 +1913,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { } else if (auto* ptr_gvar = op->op.as()) { auto gvar = ffi::GetRef(ptr_gvar); auto it = functions_.find(ptr_gvar); - ICHECK(it != functions_.end()) << "Call to undefined GlobalVar \"" << gvar << "\""; + TVM_FFI_ICHECK(it != functions_.end()) << "Call to undefined GlobalVar \"" << gvar << "\""; llvm::Function* callee = it->second; std::vector arg_value; for (const auto& arg : op->args) { @@ -1915,14 +1922,14 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { return builder_->CreateCall(callee, arg_value); } else { - LOG(FATAL) << "Unsupported operation in CallNode: " << op->op; + TVM_FFI_THROW(InternalError) << "Unsupported operation in CallNode: " << op->op; } } llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) { llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype)); // TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455 - ICHECK(!op->dtype.is_scalable_vector()); + TVM_FFI_ICHECK(!op->dtype.is_scalable_vector()); int lanes = op->dtype.lanes(); for (int i = 0; i < lanes; ++i) { vec = builder_->CreateInsertElement( @@ -1942,8 +1949,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { std::vector idx(op->indices.size()); for (int i = 0, e = op->indices.size(); i < e; ++i) { const int64_t* val = as_const_int(op->indices[i]); - ICHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, " - << "but get " << op->indices[i] << "\n"; + TVM_FFI_ICHECK(val && *val >= 0 && *val < total_lanes) + << "Shuffled indeces are suppose to be int, " + << "but get " << op->indices[i] << "\n"; idx[i] = *val; } llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx); @@ -1971,7 +1979,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) { #endif llvm::Constant* mask = llvm::ConstantVector::getSplat(ec, zero); #else - ICHECK(!dtype.is_scalable_vector()) + TVM_FFI_ICHECK(!dtype.is_scalable_vector()) << "Versions of LLVM < 11 do not support scalable vectors. Please upgrade to a later " "version."; llvm::Constant* mask = llvm::ConstantVector::getSplat(dtype.lanes(), zero); @@ -1996,7 +2004,7 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { } if (predicate != nullptr) { - ICHECK(!is_volatile) + TVM_FFI_ICHECK(!is_volatile) << "The masked store intrinsic does not support declaring store as volatile."; #if TVM_LLVM_VERSION >= 110 store = @@ -2028,7 +2036,7 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, " << " consider set unroll_explicit=True"; } else { - ICHECK(op->kind == ForKind::kSerial); + TVM_FFI_ICHECK(op->kind == ForKind::kSerial); } PrimExpr step = op->step.value_or(make_const(op->extent->dtype, 1)); PrimExpr end = is_zero(op->min) ? op->extent : analyzer_->Simplify(op->min + op->extent); @@ -2080,15 +2088,15 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { EmitDebugLocation(op); - ICHECK_EQ(op->extents.size(), 1) + TVM_FFI_ICHECK_EQ(op->extents.size(), 1) << "LLVM codegen only supports flat 1-d buffer allocation, but allocation of " << op->buffer_var->name_hint << " is " << op->extents << "-d"; - ICHECK(!is_zero(op->condition)); + TVM_FFI_ICHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; int32_t constant_size = op->ConstantAllocationSize(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation"; + TVM_FFI_ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; if (constant_size % 4 == 0 && info.alignment == 0) { info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); @@ -2124,7 +2132,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { buf, llvmGetPointerTo(DTypeToLLVMType(op->dtype), buf->getType()->getPointerAddressSpace())); AddDebugInformation(buf, op->buffer_var); - ICHECK(!var_map_.count(op->buffer_var.get())); + TVM_FFI_ICHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); } @@ -2141,7 +2149,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { } } else if (op->attr_key == tir::attr::storage_alignment) { const VarNode* v = op->node.as(); - ICHECK(v); + TVM_FFI_ICHECK(v); alloc_storage_info_[v].alignment = static_cast(op->value.as()->value); if (var_map_.count(v) && alloc_storage_info_[v].alignment > 1) { builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v), @@ -2149,7 +2157,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { } } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); - ICHECK(v); + TVM_FFI_ICHECK(v); volatile_buf_.insert(v); } this->VisitStmt(op->body); @@ -2165,7 +2173,7 @@ void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) { void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) { EmitDebugLocation(op); const VarNode* v = op->var.get(); - ICHECK(!var_map_.count(v)); + TVM_FFI_ICHECK(!var_map_.count(v)); if (v->dtype.is_handle()) { if (!is_restricted_) { alias_var_set_.insert(v); @@ -2178,7 +2186,7 @@ void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) { // need to introduce a pointer-cast, even though pointer-to-pointer // casts are not expressible with the `tir::CastNode`. if (v->dtype.is_handle() && v->type_annotation.defined()) { - CHECK(op->value->dtype.is_handle()) + TVM_FFI_ICHECK(op->value->dtype.is_handle()) << "Variable " << op->var << " is a pointer with type " << op->value << ", but is being bound to expression with type " << op->value->dtype; auto* llvm_type = GetLLVMType(v->type_annotation); @@ -2244,9 +2252,9 @@ void CodeGenLLVM::EmitDebugLocation(const StmtNode* op) { EmitDebugLocation(op-> void CodeGenLLVM::AddDebugInformation(llvm::Function* f_llvm, const ffi::Array& tvm_param_types) { #if TVM_LLVM_VERSION >= 50 - ICHECK(di_subprogram_); + TVM_FFI_ICHECK(di_subprogram_); f_llvm->setSubprogram(di_subprogram_); - ICHECK_EQ(f_llvm->getSubprogram(), di_subprogram_); + TVM_FFI_ICHECK_EQ(f_llvm->getSubprogram(), di_subprogram_); IRBuilder builder(&f_llvm->getEntryBlock()); if (!f_llvm->getEntryBlock().empty()) { @@ -2256,7 +2264,7 @@ void CodeGenLLVM::AddDebugInformation(llvm::Function* f_llvm, builder.SetCurrentDebugLocation(DL); llvm::LLVMContext* ctx = llvm_target_->GetContext(); - ICHECK_EQ(f_llvm->arg_size(), tvm_param_types.size()); + TVM_FFI_ICHECK_EQ(f_llvm->arg_size(), tvm_param_types.size()); for (auto iter_param = f_llvm->arg_begin(); iter_param != f_llvm->arg_end(); iter_param++) { size_t i = std::distance(f_llvm->arg_begin(), iter_param); auto* paramAlloca = builder.CreateAlloca(iter_param->getType()); @@ -2364,7 +2372,7 @@ llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) } else if (ty_llvm->isPointerTy()) { auto* ptr_type = ty_tir.as(); - ICHECK(ptr_type != nullptr || GetRuntimeDataType(ty_tir).is_handle()) + TVM_FFI_ICHECK(ptr_type != nullptr || GetRuntimeDataType(ty_tir).is_handle()) << "Got LLVM pointer type from non-pointer IR type: " << ty_tir; auto* pointee_type = ptr_type != nullptr ? GetDebugType(ptr_type->element_type, GetLLVMType(ptr_type->element_type)) @@ -2396,7 +2404,7 @@ llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) std::string type_str; llvm::raw_string_ostream rso(type_str); ty_llvm->print(rso); - LOG(FATAL) << "Unknown LLVM type:" << rso.str(); + TVM_FFI_THROW(InternalError) << "Unknown LLVM type:" << rso.str(); } return nullptr; } diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 17a90477d2fc..4069f20d7451 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -83,7 +83,7 @@ class CodeGenNVPTX : public CodeGenLLVM { } void VisitStmt_(const AllocateNode* op) final { - ICHECK(!is_zero(op->condition)); + TVM_FFI_ICHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; // maximum necessary alignment in the NV devices @@ -98,7 +98,8 @@ class CodeGenNVPTX : public CodeGenLLVM { AllocateSharedMemory(op->dtype, 0, 3, info.alignment, llvm::GlobalValue::ExternalLinkage); } else { size_t constant_size = op->ConstantAllocationSize(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; + TVM_FFI_ICHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation in GPU"; if (constant_size % 4 == 0 && info.alignment == 0) { info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); @@ -123,7 +124,7 @@ class CodeGenNVPTX : public CodeGenLLVM { } buf = alloca; } else { - ICHECK(storage_scope.rank == runtime::StorageRank::kShared) + TVM_FFI_ICHECK(storage_scope.rank == runtime::StorageRank::kShared) << "Can only allocate shared or local memory inside kernel"; buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment, llvm::GlobalValue::ExternalLinkage); @@ -133,7 +134,7 @@ class CodeGenNVPTX : public CodeGenLLVM { buf = builder_->CreatePointerCast( buf, llvmGetPointerTo(DTypeToLLVMType(op->dtype), buf->getType()->getPointerAddressSpace())); - ICHECK(!var_map_.count(op->buffer_var.get())); + TVM_FFI_ICHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); } @@ -154,10 +155,10 @@ class CodeGenNVPTX : public CodeGenLLVM { intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z; break; default: - LOG(FATAL) << "unknown thread idx"; + TVM_FFI_THROW(InternalError) << "unknown thread idx"; } } else { - ICHECK_EQ(ts.rank, 0); + TVM_FFI_ICHECK_EQ(ts.rank, 0); switch (ts.dim_index) { case 0: intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x; @@ -169,7 +170,7 @@ class CodeGenNVPTX : public CodeGenLLVM { intrin_id = llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z; break; default: - LOG(FATAL) << "unknown thread idx"; + TVM_FFI_THROW(InternalError) << "unknown thread idx"; } } #if TVM_LLVM_VERSION >= 200 @@ -200,7 +201,7 @@ class CodeGenNVPTX : public CodeGenLLVM { #endif return builder_->CreateCall(f, {}); } else { - LOG(FATAL) << "Do not support sync " << sync; + TVM_FFI_THROW(InternalError) << "Do not support sync " << sync; } } @@ -288,7 +289,7 @@ llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) { auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true); return builder_->CreateCall(val); } else if (op->op.same_as(builtin::atomic_add())) { - ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now"; + TVM_FFI_ICHECK(op->args[1]->dtype.bits() == 32) << "Only supports 32 bit atomic for now"; llvm::Value* v0 = MakeValue(op->args[0]); llvm::Value* v1 = MakeValue(op->args[1]); if (op->args[1]->dtype.is_float()) { @@ -301,7 +302,7 @@ llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) { llvm::AtomicOrdering::Monotonic); #endif #else - LOG(FATAL) << "Floating point atomic requires LLVM 9 or newer"; + TVM_FFI_THROW(InternalError) << "Floating point atomic requires LLVM 9 or newer"; #endif } #if TVM_LLVM_VERSION >= 130 @@ -317,7 +318,7 @@ llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) { int GetCUDAComputeVersion(const Target& target) { ffi::Optional mcpu = target->GetAttr("mcpu"); - ICHECK(mcpu.has_value()) << "InternalError: \"-mcpu\" is undefined in the NVPTX target"; + TVM_FFI_CHECK(mcpu.has_value(), InternalError) << "\"-mcpu\" is undefined in the NVPTX target"; std::string sm_version = mcpu.value(); return std::stoi(sm_version.substr(3)); } @@ -359,17 +360,19 @@ ffi::Module BuildNVPTX(IRModule mod, Target target) { // emit ptx llvm::legacy::PassManager pass; #if TVM_LLVM_VERSION <= 60 - ICHECK(tm->addPassesToEmitFile(pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + TVM_FFI_ICHECK(tm->addPassesToEmitFile(pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == + 0) << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 90 - ICHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == - 0) + TVM_FFI_ICHECK( + tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 170 - ICHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::CGFT_AssemblyFile) == 0) + TVM_FFI_ICHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #else - ICHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::CodeGenFileType::AssemblyFile) == 0) + TVM_FFI_ICHECK( + tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::CodeGenFileType::AssemblyFile) == 0) << "Cannot emit target CodeGenFileType::ObjectFile"; #endif pass.run(*module); diff --git a/src/target/llvm/codegen_params.cc b/src/target/llvm/codegen_params.cc index e2e5323445c8..fccc92a22830 100644 --- a/src/target/llvm/codegen_params.cc +++ b/src/target/llvm/codegen_params.cc @@ -74,10 +74,11 @@ llvm::ConstantArray* TensorToLLVMArray(llvm::LLVMContext* ctx, ::tvm::runtime::T llvm::Type* element_type = nullptr; auto arr_type = arr.DataType(); - CHECK(arr.IsContiguous()) << "CodegenParams: only support contiguous arrays"; - CHECK_EQ(arr->device.device_type, kDLCPU) << "CodegenParams: only support contiguous arrays"; - CHECK_EQ(arr_type.lanes(), 1) << "CodegenParams: only support generating 1-lane parameters; saw " - << arr_type.lanes(); + TVM_FFI_ICHECK(arr.IsContiguous()) << "CodegenParams: only support contiguous arrays"; + TVM_FFI_ICHECK_EQ(arr->device.device_type, kDLCPU) + << "CodegenParams: only support contiguous arrays"; + TVM_FFI_ICHECK_EQ(arr_type.lanes(), 1) + << "CodegenParams: only support generating 1-lane parameters; saw " << arr_type.lanes(); auto shape = arr.Shape(); int num_elements = 1; @@ -89,8 +90,8 @@ llvm::ConstantArray* TensorToLLVMArray(llvm::LLVMContext* ctx, ::tvm::runtime::T switch (arr_type.code()) { case runtime::DataType::kInt: - CHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 || - arr_type.bits() == 64) + TVM_FFI_ICHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 || + arr_type.bits() == 64) << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw " << arr_type.bits() << "-bit array"; element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits()); @@ -109,14 +110,14 @@ llvm::ConstantArray* TensorToLLVMArray(llvm::LLVMContext* ctx, ::tvm::runtime::T BuildLLVMVector(element_type, arr->data, num_elements, &elements); break; default: - ICHECK(false) << "should not get here"; + TVM_FFI_ICHECK(false) << "should not get here"; break; } break; case runtime::DataType::TypeCode::kUInt: - CHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 || - arr_type.bits() == 64) + TVM_FFI_ICHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 || + arr_type.bits() == 64) << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw " << arr_type.bits() << "-bit array"; element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits()); @@ -135,7 +136,7 @@ llvm::ConstantArray* TensorToLLVMArray(llvm::LLVMContext* ctx, ::tvm::runtime::T BuildLLVMVector(element_type, arr->data, num_elements, &elements); break; default: - ICHECK(false) << "should not get here"; + TVM_FFI_ICHECK(false) << "should not get here"; break; } break; @@ -156,20 +157,20 @@ llvm::ConstantArray* TensorToLLVMArray(llvm::LLVMContext* ctx, ::tvm::runtime::T BuildLLVMVector(element_type, arr->data, num_elements, &elements); break; default: - CHECK(false) << "CodegenParams: only support 32- or 64-bit floating point; saw " - << arr_type.bits() << "-bit array"; + TVM_FFI_ICHECK(false) << "CodegenParams: only support 32- or 64-bit floating point; saw " + << arr_type.bits() << "-bit array"; break; } break; case runtime::DataType::TypeCode::kBFloat: - CHECK(arr_type.bits() == 16) + TVM_FFI_ICHECK(arr_type.bits() == 16) << "CodegenParams: only support 16-bit bfloat; saw " << arr_type.bits() << "-bit array"; element_type = llvm::Type::getIntNTy(*ctx, arr_type.bits()); BuildLLVMVector(element_type, arr->data, num_elements, &elements); default: - CHECK(false) << "Data type not supported"; + TVM_FFI_ICHECK(false) << "Data type not supported"; } return llvm::cast(llvm::ConstantArray::get( diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index 2666a3dc1c40..cd280d2ddc6f 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -58,7 +58,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { const auto from = op->value.dtype(); const auto to = op->dtype; if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) { - ICHECK_EQ(from.lanes(), to.lanes()); + TVM_FFI_ICHECK_EQ(from.lanes(), to.lanes()); const auto has_avx512 = llvm_target_->TargetHasCPUFeature("avx512f"); @@ -111,13 +111,13 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr // Otherwise, we split the vector into intrin_lanes sized elements (widening where necessary), // compute each result, and then concatenate the vectors (slicing the result if necessary). - ICHECK_LT(intrin_lanes, num_elems); + TVM_FFI_ICHECK_LT(intrin_lanes, num_elems); std::vector split_results; for (size_t i = 0; i < num_elems; i += intrin_lanes) { std::vector split_args; for (const auto& v : args) { if (v->getType()->isVectorTy()) { - ICHECK_EQ(GetVectorNumElements(v), num_elems); + TVM_FFI_ICHECK_EQ(GetVectorNumElements(v), num_elems); split_args.push_back(CreateVecSlice(v, i, intrin_lanes)); } else { split_args.push_back(v); diff --git a/src/target/llvm/intrin_rule_hexagon.cc b/src/target/llvm/intrin_rule_hexagon.cc index bb78af0a8434..dc71e69a2122 100644 --- a/src/target/llvm/intrin_rule_hexagon.cc +++ b/src/target/llvm/intrin_rule_hexagon.cc @@ -50,12 +50,12 @@ template inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { using namespace tir; const CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); ffi::Array new_args; #if ENABLE_QHL // Check target for qfloat enablement const auto f = tvm::ffi::Function::GetGlobal("target.TargetCurrent"); - ICHECK(f.has_value()) << "target.TargetCurrent is not registered"; + TVM_FFI_ICHECK(f.has_value()) << "target.TargetCurrent is not registered"; const auto ret = (*f)(true); bool useqhl = true; if (auto opt_target = ret.as()) { @@ -99,13 +99,13 @@ TVM_REGISTER_OP("tir.ctpop") TVM_REGISTER_OP("tir.tanh") .set_attr("hexagon.FLowerIntrinsic", [](const PrimExpr& e) { const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; #if ENABLE_QHL // Check target for qfloat enablement const auto f = tvm::ffi::Function::GetGlobal("target.TargetCurrent"); - ICHECK(f.has_value()) << "target.TargetCurrent is not registered"; + TVM_FFI_ICHECK(f.has_value()) << "target.TargetCurrent is not registered"; const auto ret = (*f)(true); bool useqhl = true; if (auto opt_target = ret.as()) { @@ -135,12 +135,12 @@ TVM_REGISTER_OP("tir.tanh") TVM_REGISTER_OP("tir.tan").set_attr( "hexagon.FLowerIntrinsic", [](const PrimExpr& e) { const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; #if ENABLE_QHL // Check target for qfloat enablement const auto f = tvm::ffi::Function::GetGlobal("target.TargetCurrent"); - ICHECK(f.has_value()) << "target.TargetCurrent is not registered"; + TVM_FFI_ICHECK(f.has_value()) << "target.TargetCurrent is not registered"; const auto ret = (*f)(true); bool useqhl = true; if (auto opt_target = ret.as()) { @@ -165,12 +165,12 @@ TVM_REGISTER_OP("tir.nearbyint") TVM_REGISTER_OP("tir.sigmoid") .set_attr("hexagon.FLowerIntrinsic", [](const PrimExpr& e) { const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; #if ENABLE_QHL // Check target for qfloat enablement const auto f = tvm::ffi::Function::GetGlobal("target.TargetCurrent"); - ICHECK(f.has_value()) << "target.TargetCurrent is not registered"; + TVM_FFI_ICHECK(f.has_value()) << "target.TargetCurrent is not registered"; const auto ret = (*f)(true); bool useqhl = true; if (auto opt_target = ret.as()) { diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index a8a3d911ca8e..4406a5949052 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -119,7 +119,7 @@ TVM_REGISTER_OP("tir.exp10") using tir::make_const; using tir::make_zero; const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; PrimExpr ln10 = make_const(x.dtype(), 2.302585093); PrimExpr ret = exp(x * ln10); @@ -128,7 +128,7 @@ TVM_REGISTER_OP("tir.exp10") TVM_REGISTER_OP("tir.tan").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; PrimExpr tan_x = sin(x) / cos(x); return tan_x; @@ -139,7 +139,7 @@ TVM_REGISTER_OP("tir.cosh") using tir::make_const; using tir::make_zero; const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; PrimExpr two = make_const(x.dtype(), 2); PrimExpr neg_one = make_const(x.dtype(), -1); @@ -154,7 +154,7 @@ TVM_REGISTER_OP("tir.sinh") using tir::make_const; using tir::make_zero; const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; PrimExpr two = make_const(x.dtype(), 2); PrimExpr neg_one = make_const(x.dtype(), -1); @@ -169,7 +169,7 @@ TVM_REGISTER_OP("tir.asin") using tir::make_const; using namespace intrin; const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; PrimExpr threshold = make_const(x.dtype(), 0.5); @@ -201,7 +201,7 @@ TVM_REGISTER_OP("tir.acos") using tir::make_const; using namespace intrin; const tir::CallNode* call = e.as(); - ICHECK(call != nullptr) << "Invalid call node in acos legalization"; + TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in acos legalization"; const PrimExpr& x = call->args[0]; PrimExpr threshold = make_const(x.dtype(), 0.5); @@ -227,7 +227,7 @@ TVM_REGISTER_OP("tir.atan") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; const tir::CallNode* call = e.as(); - ICHECK(call != nullptr) << "Invalid call node in atan legalization"; + TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in atan legalization"; const PrimExpr& x = call->args[0]; PrimExpr one = make_const(x.dtype(), 1.0); PrimExpr denom = sqrt(x * x + one); @@ -238,7 +238,7 @@ TVM_REGISTER_OP("tir.asinh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; const tir::CallNode* call = e.as(); - ICHECK(call != nullptr) << "Invalid call node in asinh legalization"; + TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in asinh legalization"; const PrimExpr& x = call->args[0]; PrimExpr one = make_const(x.dtype(), 1.0); PrimExpr sqrt_val = sqrt(x * x + one); @@ -249,7 +249,7 @@ TVM_REGISTER_OP("tir.acosh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; const tir::CallNode* call = e.as(); - ICHECK(call != nullptr) << "Invalid call node in acosh legalization"; + TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in acosh legalization"; const PrimExpr& x = call->args[0]; PrimExpr one = make_const(x.dtype(), 1.0); PrimExpr sqrt_val = sqrt(x * x - one); @@ -260,7 +260,7 @@ TVM_REGISTER_OP("tir.atanh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; const tir::CallNode* call = e.as(); - ICHECK(call != nullptr) << "Invalid call node in atanh legalization"; + TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in atanh legalization"; const PrimExpr& x = call->args[0]; PrimExpr one = make_const(x.dtype(), 1.0); return (log(one + x) - log(one - x)) * make_const(x.dtype(), 0.5); @@ -269,7 +269,7 @@ TVM_REGISTER_OP("tir.atanh") TVM_REGISTER_OP("tir.erf").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; const tir::CallNode* call = e.as(); - ICHECK(call != nullptr) << "Invalid call node in erf legalization"; + TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in erf legalization"; const PrimExpr& x = call->args[0]; PrimExpr abs_x = tvm::abs(x); PrimExpr t = make_const(x.dtype(), 1.0) / @@ -286,8 +286,8 @@ TVM_REGISTER_OP("tir.erf").set_attr("llvm.FLegalize", [](const PrimEx TVM_REGISTER_OP("tir.clz").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); - ICHECK_EQ(call->args.size(), 1); + TVM_FFI_ICHECK(call != nullptr); + TVM_FFI_ICHECK_EQ(call->args.size(), 1); ffi::Array cargs; cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz)); cargs.push_back(call->args[0]); diff --git a/src/target/llvm/intrin_rule_llvm.h b/src/target/llvm/intrin_rule_llvm.h index 445d33522c7e..f1bed6378060 100644 --- a/src/target/llvm/intrin_rule_llvm.h +++ b/src/target/llvm/intrin_rule_llvm.h @@ -40,11 +40,11 @@ namespace codegen { template inline PrimExpr DispatchLLVMPureIntrin(const PrimExpr& e) { const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); ffi::Array cargs; // intrin id. cargs.push_back(IntImm(DataType::UInt(32), id)); - ICHECK_EQ(call->args.size(), num_signature) + TVM_FFI_ICHECK_EQ(call->args.size(), num_signature) << "llvm.call_llvm_intrin" << llvmGetIntrinName(id) << "expects " << num_signature << " arguments, but got " << call->args.size(); @@ -57,11 +57,11 @@ inline PrimExpr DispatchLLVMPureIntrin(const PrimExpr& e) { template inline PrimExpr DispatchLLVMIntrin(const PrimExpr& e) { const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); ffi::Array cargs; // intrin id. cargs.push_back(IntImm(DataType::UInt(32), id)); - ICHECK_EQ(call->args.size(), num_signature) + TVM_FFI_ICHECK_EQ(call->args.size(), num_signature) << "llvm.call_llvm_intrin" << llvmGetIntrinName(id) << "expects " << num_signature << " arguments, but got " << call->args.size(); for (PrimExpr arg : call->args) { diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index a5fef4f5d411..42f1352e36f3 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -36,14 +36,14 @@ namespace codegen { inline PrimExpr DispatchPureExternLibDevice(const PrimExpr& e) { using namespace tir; const CallNode* call = e.as(); - ICHECK(call != nullptr); - ICHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) + TVM_FFI_ICHECK(call != nullptr); + TVM_FFI_ICHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) << "Only support float32 or float64."; const OpNode* op = call->op.as(); - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); std::string name = op->name; - ICHECK_EQ(name.substr(0, 4), "tir."); + TVM_FFI_ICHECK_EQ(name.substr(0, 4), "tir."); std::ostringstream intrinsic_name; intrinsic_name << "__nv_" << name.substr(4); diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index d4c92a38d1ba..eec2cf2d1dc0 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -42,12 +42,12 @@ inline PrimExpr DispatchPureExternOCML(const PrimExpr& e) { // extreme caution. using namespace tir; const CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); const OpNode* op = call->op.as(); - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); std::string name = op->name; - ICHECK_EQ(name.substr(0, 4), "tir."); + TVM_FFI_ICHECK_EQ(name.substr(0, 4), "tir."); std::ostringstream intrinsic_name; intrinsic_name << "__ocml_" << name.substr(4) << "_f" << call->dtype.bits(); @@ -63,10 +63,10 @@ inline PrimExpr DispatchPureExternOCML(const PrimExpr& e) { inline PrimExpr DispatchShuffle(const PrimExpr& e) { using namespace tir; const CallNode* call = e.as(); - ICHECK(call != nullptr); - ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size + TVM_FFI_ICHECK(call != nullptr); + TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size PrimExpr var = call->args[1]; - ICHECK_EQ(var.dtype().bits(), 32); + TVM_FFI_ICHECK_EQ(var.dtype().bits(), 32); // get own lane in self (__lane_id) PrimExpr minus_one = tir::make_const(DataType::Int(32), -1); @@ -87,7 +87,7 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { index = self - delta; index = Select(index < (self & ~(width - 1)), self, index); } else { - ICHECK(call->op.same_as(builtin::tvm_warp_shuffle_down())); + TVM_FFI_ICHECK(call->op.same_as(builtin::tvm_warp_shuffle_down())); PrimExpr delta = call->args[2]; index = self + delta; index = Select((self & (width - 1)) + delta >= width, self, index); diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index 4e8b072f01cb..c9e855c974a3 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -161,7 +161,7 @@ std::unique_ptr LLVMInstance::LoadIR(const std::string& file_name) llvm::ErrorOr> maybe_buffer = llvm::MemoryBuffer::getFileAsStream(file_name); if (std::error_code ec = maybe_buffer.getError()) { - LOG(FATAL) << ec.message(); + TVM_FFI_THROW(InternalError) << ec.message(); } return ParseBuffer(**maybe_buffer); } @@ -173,7 +173,7 @@ std::unique_ptr LLVMInstance::ParseBuffer(const llvm::MemoryBuffer std::string message; llvm::raw_string_ostream ostream(message); error.print(/*ProgName=*/nullptr, ostream, /*ShowColors=*/false, /*ShowKindLabel=*/true); - LOG(FATAL) << ostream.str(); + TVM_FFI_THROW(InternalError) << ostream.str(); } return module; @@ -249,7 +249,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, LOG(ERROR) << "\"" << opt.name << "\" is not an LLVM option, option ignored"; } } - ICHECK(!parse_error) << "there were errors parsing command-line options"; + TVM_FFI_ICHECK(!parse_error) << "there were errors parsing command-line options"; } llvm::FloatABI::ABIType float_abi = llvm::FloatABI::Default; @@ -260,7 +260,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, } else if (value == "soft") { float_abi = llvm::FloatABI::Soft; } else { - LOG(FATAL) << "invalid -mfloat-abi option " << value; + TVM_FFI_THROW(InternalError) << "invalid -mfloat-abi option " << value; } } @@ -270,7 +270,8 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, if ((value == "mcjit") || (value == "orcjit")) { jit_engine_ = value; } else { - LOG(FATAL) << "invalid jit option " << value << " (can be `orcjit` or `mcjit`)."; + TVM_FFI_THROW(InternalError) + << "invalid jit option " << value << " (can be `orcjit` or `mcjit`)."; } } @@ -279,7 +280,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, Downcast>(target.Get("vector-width").value_or(nullptr))) { vector_width_ = w.value(); if ((vector_width_ <= 0) || (vector_width_ > 65536)) { - LOG(FATAL) << "Invalid -vector-width value: " << vector_width_; + TVM_FFI_THROW(InternalError) << "Invalid -vector-width value: " << vector_width_; } } @@ -417,7 +418,7 @@ static const llvm::Target* CreateLLVMTargetInstance(const std::string triple, // required mimimum: llvm::InitializeAllTargets() const llvm::Target* llvm_instance = llvm::TargetRegistry::lookupTarget(triple, error); if (!allow_missing && !llvm_instance) { - ICHECK(llvm_instance) << "LLVM instance error: `" << error << "`"; + TVM_FFI_ICHECK(llvm_instance) << "LLVM instance error: `" << error << "`"; } return llvm_instance; @@ -435,7 +436,7 @@ static std::unique_ptr CreateLLVMTargetMachine( #endif llvm::TargetMachine* tm = llvm_instance->createTargetMachine( triple, cpu, features, target_options, reloc_model, code_model, opt_level); - ICHECK(tm != nullptr); + TVM_FFI_ICHECK(tm != nullptr); return std::unique_ptr(tm); } @@ -449,7 +450,7 @@ llvm::TargetMachine* LLVMTargetInfo::GetOrCreateTargetMachine(bool allow_missing CreateLLVMTargetMachine(llvm_instance, triple_, cpu_, GetTargetFeatureString(), target_options_, reloc_model_, code_model_, opt_level_); } - ICHECK(target_machine_ != nullptr); + TVM_FFI_ICHECK(target_machine_ != nullptr); return target_machine_.get(); } @@ -673,7 +674,7 @@ LLVMTargetInfo::Option LLVMTargetInfo::ParseOptionString(const std::string& str) part_this++; // Only advance if we saw ":". if (part_this < part_end) { auto& p1 = parts[part_this]; - ICHECK(!p1.empty()) << "tokenizing error"; // This shouldn't happen. + TVM_FFI_ICHECK(!p1.empty()) << "tokenizing error"; // This shouldn't happen. if (p1 != "=") { part_this++; if (p1 == "bool") { @@ -791,7 +792,7 @@ LLVMTargetInfo::Option LLVMTargetInfo::ParseOptionString(const std::string& str) } } - ICHECK(type != Option::OptType::Invalid); + TVM_FFI_ICHECK(type != Option::OptType::Invalid); opt.type = type; return opt; } @@ -800,7 +801,7 @@ bool LLVMTargetInfo::MatchesGlobalState() const { for (const Option& opt : GetCommandLineOptions()) { Option current_opt = opt; GetOptionValue(¤t_opt); - ICHECK(current_opt.type != Option::OptType::Invalid); + TVM_FFI_ICHECK(current_opt.type != Option::OptType::Invalid); switch (current_opt.type) { case Option::OptType::Bool: if (current_opt.value.b != opt.value.b) return false; @@ -953,7 +954,7 @@ LLVMTarget::LLVMTarget(LLVMInstance& instance, const LLVMTargetInfo& target_info } if (modified_llvm_state_) { - ICHECK(!ApplyLLVMOptions(true)); + TVM_FFI_ICHECK(!ApplyLLVMOptions(true)); } else { modified_llvm_state_ = ApplyLLVMOptions(true); } @@ -973,7 +974,7 @@ LLVMTarget::~LLVMTarget() { } llvm::LLVMContext* LLVMTarget::GetContext() const { - ICHECK(!ctx_.expired()) << "LLVM scope has been deleted"; + TVM_FFI_ICHECK(!ctx_.expired()) << "LLVM scope has been deleted"; return ctx_.lock().get(); } @@ -1038,7 +1039,7 @@ bool LLVMTarget::ApplyLLVMOptions(bool apply_otherwise_revert, bool dry_run) { auto* str_op = static_cast*>(base_op); HANDLE_OPTION_VALUE(str_op, new_opt.value.s, saved_opt.value.s); } else { - LOG(FATAL) << "unexpected type in option " << new_opt; + TVM_FFI_THROW(InternalError) << "unexpected type in option " << new_opt; } if (dry_run && changed) { diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 460e3e6f9f5e..4294c77c7e08 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -149,7 +149,7 @@ LLVMModuleNode::~LLVMModuleNode() { auto dtorRunner = std::make_unique(orcjit_ee_->getMainJITDylib()); dtorRunner->add(dtors); auto err = dtorRunner->run(); - ICHECK(!err) << llvm::toString(std::move(err)); + TVM_FFI_ICHECK(!err) << llvm::toString(std::move(err)); orcjit_ee_.reset(); } module_owning_ptr_.reset(); @@ -181,7 +181,7 @@ ffi::Optional LLVMModuleNode::GetFunction(const ffi::String& name return ffi::Function( [target_string](ffi::PackedArgs args, ffi::Any* rv) { *rv = target_string; }); } - ICHECK(jit_engine_.size()) << "JIT engine type is missing"; + TVM_FFI_ICHECK(jit_engine_.size()) << "JIT engine type is missing"; if ((jit_engine_ == "mcjit") && (mcjit_ee_ == nullptr)) InitMCJIT(); if ((jit_engine_ == "orcjit") && (orcjit_ee_ == nullptr)) InitORCJIT(); @@ -238,12 +238,13 @@ bool LLVMAddPassesToEmitFile(llvm::TargetMachine* tm, llvm::legacy::PassManager* void LLVMModuleNode::WriteToFile(const ffi::String& file_name_str, const ffi::String& format) const { - // CHECK(imports_.empty()) << "SaveToFile does not handle imported modules"; + // TVM_FFI_ICHECK(imports_.empty()) << "SaveToFile does not handle imported modules"; std::string file_name = file_name_str; std::string fmt = runtime::GetFileFormat(file_name, format); std::error_code ecode; llvm::raw_fd_ostream dest(file_name, ecode, llvm_open_output_flag); - ICHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " << ecode.message(); + TVM_FFI_ICHECK_EQ(ecode.value(), 0) + << "Cannot open file: " << file_name << " " << ecode.message(); bool is_obj_file = fmt == "o" || fmt == "obj"; bool is_asm_file = fmt == "s" || fmt == "asm"; if (is_obj_file || is_asm_file) { @@ -254,7 +255,7 @@ void LLVMModuleNode::WriteToFile(const ffi::String& file_name_str, llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); auto err = LLVMAddPassesToEmitFile(tm, &pass, &dest, llvm_file_target); - ICHECK(!err) << "Cannot emit target CGFT_ObjectFile"; + TVM_FFI_ICHECK(!err) << "Cannot emit target CGFT_ObjectFile"; pass.run(*CloneLLVMModule(module_)); } else if (fmt == "ll") { @@ -266,14 +267,14 @@ void LLVMModuleNode::WriteToFile(const ffi::String& file_name_str, llvm::WriteBitcodeToFile(*module_, dest); #endif } else { - LOG(FATAL) << "Do not know how to save file " << file_name << " with format=\'" << format - << "\'"; + TVM_FFI_THROW(InternalError) << "Do not know how to save file " << file_name + << " with format=\'" << format << "\'"; } dest.close(); } ffi::Bytes LLVMModuleNode::SaveToBytes() const { - LOG(FATAL) << "LLVMModule: SaveToBytes not supported"; + TVM_FFI_THROW(InternalError) << "LLVMModule: SaveToBytes not supported"; } ffi::String LLVMModuleNode::InspectSource(const ffi::String& format) const { @@ -292,16 +293,18 @@ ffi::String LLVMModuleNode::InspectSource(const ffi::String& format) const { llvm::legacy::PassManager pass; llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); #if TVM_LLVM_VERSION <= 60 - ICHECK(tm->addPassesToEmitFile(pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + TVM_FFI_ICHECK(tm->addPassesToEmitFile(pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 90 - ICHECK(tm->addPassesToEmitFile(pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + TVM_FFI_ICHECK( + tm->addPassesToEmitFile(pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 170 - ICHECK(tm->addPassesToEmitFile(pass, rso, nullptr, llvm::CGFT_AssemblyFile) == 0) + TVM_FFI_ICHECK(tm->addPassesToEmitFile(pass, rso, nullptr, llvm::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #else - ICHECK(tm->addPassesToEmitFile(pass, rso, nullptr, llvm::CodeGenFileType::AssemblyFile) == 0) + TVM_FFI_ICHECK( + tm->addPassesToEmitFile(pass, rso, nullptr, llvm::CodeGenFileType::AssemblyFile) == 0) << "Cannot emit target CodeGenFileType::AssemblyFile"; #endif pass.run(*m); @@ -309,11 +312,12 @@ ffi::String LLVMModuleNode::InspectSource(const ffi::String& format) const { } else if (fmt == "" || fmt == "ll") { std::string type_str; llvm::raw_string_ostream rso(type_str); - ICHECK(module_ != nullptr); + TVM_FFI_ICHECK(module_ != nullptr); module_->print(rso, nullptr); return rso.str(); } else { - LOG(FATAL) << "Do not know how to get source code with format: " << format << "\'"; + TVM_FFI_THROW(InternalError) << "Do not know how to get source code with format: " << format + << "\'"; } return ""; } @@ -338,7 +342,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); bool is_entry_func = f->HasNonzeroAttr(tir::attr::kIsEntryFunc); - ICHECK(global_symbol || !is_entry_func) << "The entry func must be exposed externally."; + TVM_FFI_ICHECK(global_symbol || !is_entry_func) << "The entry func must be exposed externally."; if (global_symbol) { function_names_.push_back(global_symbol.value()); @@ -348,7 +352,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { } } // TODO(@jroesch): follow up on this condition. - // ICHECK(funcs.size() > 0); + // TVM_FFI_ICHECK(funcs.size() > 0); // TODO(tqchen): remove the entry function behavior as it does not // makes sense when we start to use multiple modules. cg->Init("TVMMod", llvm_target.get(), system_lib_prefix, system_lib_prefix.has_value(), false); @@ -416,23 +420,23 @@ void LLVMModuleNode::InitMCJIT() { // create the taget machine auto tm = std::unique_ptr(builder.selectTarget()); if (!IsCompatibleWithHost(tm.get())) { - LOG(FATAL) << "Cannot run module, architecture mismatch"; + TVM_FFI_THROW(InternalError) << "Cannot run module, architecture mismatch"; } // data layout llvm::DataLayout layout(tm->createDataLayout()); - ICHECK(layout == module_->getDataLayout()) + TVM_FFI_ICHECK(layout == module_->getDataLayout()) << "Data layout mismatch between module(" << module_->getDataLayout().getStringRepresentation() << ")" << " and ExecutionEngine (" << layout.getStringRepresentation() << ")"; // create MCJIT mcjit_ee_ = builder.create(tm.release()); - ICHECK(mcjit_ee_ != nullptr) << "Failed to initialize LLVM MCJIT engine for " + TVM_FFI_ICHECK(mcjit_ee_ != nullptr) << "Failed to initialize LLVM MCJIT engine for " #if TVM_LLVM_VERSION >= 210 - << module_->getTargetTriple().str(); + << module_->getTargetTriple().str(); #else - << module_->getTargetTriple(); + << module_->getTargetTriple(); #endif VLOG(2) << "LLVM MCJIT execute " << module_->getModuleIdentifier() << " for triple `" @@ -491,13 +495,13 @@ void LLVMModuleNode::InitORCJIT() { // create the taget machine std::unique_ptr tm = llvm::cantFail(tm_builder.createTargetMachine()); if (!IsCompatibleWithHost(tm.get())) { - LOG(FATAL) << "Cannot run module, architecture mismatch"; + TVM_FFI_THROW(InternalError) << "Cannot run module, architecture mismatch"; } // data layout ffi::String module_name = module_->getModuleIdentifier(); llvm::DataLayout layout(tm->createDataLayout()); - ICHECK(layout == module_->getDataLayout()) + TVM_FFI_ICHECK(layout == module_->getDataLayout()) << "Data layout mismatch between module(" << module_->getDataLayout().getStringRepresentation() << ")" << " and ExecutionEngine (" << layout.getStringRepresentation() << ")"; @@ -558,11 +562,11 @@ void LLVMModuleNode::InitORCJIT() { #endif .create()); - ICHECK(orcjit_ee_ != nullptr) << "Failed to initialize LLVM ORCJIT engine for " + TVM_FFI_ICHECK(orcjit_ee_ != nullptr) << "Failed to initialize LLVM ORCJIT engine for " #if TVM_LLVM_VERSION >= 210 - << module_->getTargetTriple().str(); + << module_->getTargetTriple().str(); #else - << module_->getTargetTriple(); + << module_->getTargetTriple(); #endif // store ctors @@ -573,7 +577,7 @@ void LLVMModuleNode::InitORCJIT() { // resolve system symbols (like pthread, dl, m, etc.) auto gen = llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(layout.getGlobalPrefix()); - ICHECK(gen) << llvm::toString(gen.takeError()) << "\n"; + TVM_FFI_ICHECK(gen) << llvm::toString(gen.takeError()) << "\n"; orcjit_ee_->getMainJITDylib().addGenerator(std::move(gen.get())); // transfer module to a clone @@ -583,7 +587,7 @@ void LLVMModuleNode::InitORCJIT() { // add the llvm module to run llvm::orc::ThreadSafeModule tsm(std::move(umod), std::move(uctx)); auto err = orcjit_ee_->addIRModule(std::move(tsm)); - ICHECK(!err) << llvm::toString(std::move(err)); + TVM_FFI_ICHECK(!err) << llvm::toString(std::move(err)); VLOG(2) << "LLVM ORCJIT execute " << module_->getModuleIdentifier() << " for triple `" << llvm_target->GetTargetTriple() << "`" @@ -591,7 +595,7 @@ void LLVMModuleNode::InitORCJIT() { // run ctors err = ctorRunner.run(); - ICHECK(!err) << llvm::toString(std::move(err)); + TVM_FFI_ICHECK(!err) << llvm::toString(std::move(err)); if (void** ctx_addr = reinterpret_cast(GetGlobalAddr(ffi::symbol::tvm_ffi_library_ctx, *llvm_target))) { @@ -629,7 +633,7 @@ void* LLVMModuleNode::GetGlobalAddr(const std::string& name, const LLVMTarget& l #endif return reinterpret_cast(addr); } else { - LOG(FATAL) << "Either `mcjit` or `orcjit` are not initialized."; + TVM_FFI_THROW(InternalError) << "Either `mcjit` or `orcjit` are not initialized."; } } return nullptr; @@ -649,7 +653,7 @@ void* LLVMModuleNode::GetFunctionAddr(const std::string& name, #endif return reinterpret_cast(addr); } else { - LOG(FATAL) << "Either `mcjit` or `orcjit` are not initialized."; + TVM_FFI_THROW(InternalError) << "Either `mcjit` or `orcjit` are not initialized."; } } return nullptr; diff --git a/src/target/opt/build_cuda_off.cc b/src/target/opt/build_cuda_off.cc index 339d07fd7338..bf5c5d63d471 100644 --- a/src/target/opt/build_cuda_off.cc +++ b/src/target/opt/build_cuda_off.cc @@ -26,7 +26,7 @@ namespace runtime { ffi::Module CUDAModuleCreate(std::string data, std::string fmt, ffi::Map fmap, std::string cuda_source) { - LOG(FATAL) << "CUDA is not enabled"; + TVM_FFI_THROW(InternalError) << "CUDA is not enabled"; TVM_FFI_UNREACHABLE(); } } // namespace runtime diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 8cc1472172c5..4e312e93a462 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -51,12 +51,12 @@ ffi::Module BuildCUDA(IRModule mod, Target target) { ffi::Map functions; for (auto [gvar, base_func] : mod->functions) { - ICHECK(base_func->IsInstance()) << "CodeGenCUDA: Can only take PrimFunc"; + TVM_FFI_ICHECK(base_func->IsInstance()) << "CodeGenCUDA: Can only take PrimFunc"; auto prim_func = Downcast(base_func); auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDefault)); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch || - calling_conv == CallingConv::kDefault) + TVM_FFI_ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch || + calling_conv == CallingConv::kDefault) << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch or " "CallingConv::kDefault"; functions.Set(gvar, prim_func); @@ -80,7 +80,7 @@ ffi::Module BuildCUDA(IRModule mod, Target target) { // Always use Python compilation callback (nvcc or nvrtc) // The C++ NVRTC fallback has been removed in favor of Python-first approach auto f_compile = ffi::Function::GetGlobal("tvm_callback_cuda_compile"); - ICHECK(f_compile != nullptr) + TVM_FFI_ICHECK(f_compile != nullptr) << "tvm_callback_cuda_compile not found. " << "Please ensure TVM Python runtime is properly initialized.\n" << "The Python callback (tvm.contrib.nvcc.tvm_callback_cuda_compile) is required " diff --git a/src/target/opt/build_opencl_off.cc b/src/target/opt/build_opencl_off.cc index 8a6d1b51fdaa..1a27866a4c94 100644 --- a/src/target/opt/build_opencl_off.cc +++ b/src/target/opt/build_opencl_off.cc @@ -34,7 +34,7 @@ ffi::Module OpenCLModuleCreate(std::string data, std::string fmt, ffi::Module OpenCLModuleCreate(const std::unordered_map& shaders, const std::string& spirv_text, ffi::Map fmap) { - LOG(FATAL) << "OpenCLModuleCreate is called but OpenCL is not enabled."; + TVM_FFI_THROW(InternalError) << "OpenCLModuleCreate is called but OpenCL is not enabled."; TVM_FFI_UNREACHABLE(); } diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 732774e31849..d3e1cee46eba 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -139,7 +139,7 @@ void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) { auto function_name = [&]() -> ffi::String { if (auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol)) { auto name = global_symbol.value(); - ICHECK(!func_name_supply_->ContainsName(name)) + TVM_FFI_ICHECK(!func_name_supply_->ContainsName(name)) << "Function " << gvar << " must use global symbol " << name << ", but this name has already been used."; func_name_supply_->ReserveName(name); @@ -161,7 +161,7 @@ void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) { ffi::String CodeGenC::GetFunctionName(const GlobalVar& gvar) { auto it = internal_functions_.find(gvar); - ICHECK(it != internal_functions_.end()) + TVM_FFI_ICHECK(it != internal_functions_.end()) << "Attempted to find name of " << gvar << ", but no function with this GlobalVar has been declared"; return it->second; @@ -327,7 +327,7 @@ std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const Pri os << "device.device_type"; break; default: - LOG(FATAL) << "unknown field code"; + TVM_FFI_THROW(InternalError) << "unknown field code"; } os << ')'; return os.str(); @@ -355,7 +355,7 @@ std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const Pri } else if (t.is_int()) { os << "v_int64"; } else { - LOG(FATAL) << "Do not know how to handle type" << t; + TVM_FFI_THROW(InternalError) << "Do not know how to handle type" << t; } os << ")"; return os.str(); @@ -375,7 +375,7 @@ void CodeGenC::RegisterHandleType(const VarNode* buf_var, DataType t) { if (it == handle_data_type_.end()) { handle_data_type_[buf_var] = t; } else { - ICHECK(it->second == t) << "conflicting buf var type"; + TVM_FFI_ICHECK(it->second == t) << "conflicting buf var type"; } } @@ -414,13 +414,15 @@ std::string CodeGenC::CastFromTo(std::string value, DataType from, DataType targ return os.str(); } -void CodeGenC::BindThreadIndex(const IterVar& iv) { LOG(FATAL) << "not implemented"; } +void CodeGenC::BindThreadIndex(const IterVar& iv) { + TVM_FFI_THROW(InternalError) << "not implemented"; +} void CodeGenC::PrintStorageSync(const CallNode* op) { // NOLINT(*) } void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) - ICHECK_EQ(scope, "global"); + TVM_FFI_ICHECK_EQ(scope, "global"); } inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) @@ -468,7 +470,7 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { break; } default: - LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; + TVM_FFI_THROW(InternalError) << "Bad bit-width for float: " << op->dtype << "\n"; } } @@ -510,7 +512,7 @@ inline void PrintBinaryIntrinsic(const CallNode* op, const char* opstr, std::ostream& os, // NOLINT(*) CodeGenC* p) { if (op->dtype.lanes() == 1) { - ICHECK_EQ(op->args.size(), 2U); + TVM_FFI_ICHECK_EQ(op->args.size(), 2U); os << '('; p->PrintExpr(op->args[0], os); os << opstr; @@ -544,14 +546,14 @@ void CodeGenC::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*) if (op->dtype.is_int() || op->dtype.is_uint()) { PrintBinaryExpr(op, "%", os, this); } else { - ICHECK(op->dtype.is_float()) << "Expected floating point or integer dtype in Mod, but got " - << op->dtype; + TVM_FFI_ICHECK(op->dtype.is_float()) + << "Expected floating point or integer dtype in Mod, but got " << op->dtype; if (op->dtype.bits() == 32) { PrintBinaryExpr(op, "fmodf", os, this); } else if (op->dtype.bits() == 64) { PrintBinaryExpr(op, "fmod", os, this); } else { - ICHECK(false) + TVM_FFI_ICHECK(false) << "Non single or double precision floating point in Mod, expected 32 or 64 bits but got " << op->dtype.bits() << " bits."; } @@ -617,7 +619,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) } else if (op->op.same_as(builtin::break_loop())) { os << "break;"; } else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { - ICHECK_GE(op->args.size(), 1U); + TVM_FFI_ICHECK_GE(op->args.size(), 1U); auto func = Downcast(op->args[0]); this->PrintCallExtern(GetType(ffi::GetRef(op)), func->value, op->args, true, os); @@ -638,7 +640,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) } else if (op->op.same_as(builtin::bitwise_and())) { PrintBinaryIntrinsic(op, " & ", os, this); } else if (op->op.same_as(builtin::large_uint_imm())) { - ICHECK_EQ(op->args.size(), 2U); + TVM_FFI_ICHECK_EQ(op->args.size(), 2U); uint64_t low = static_cast(Downcast(op->args[0])->value); uint64_t high = static_cast(Downcast(op->args[1])->value); uint64_t val = (high << 32U) | low; @@ -648,7 +650,7 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) } else if (op->op.same_as(builtin::bitwise_or())) { PrintBinaryIntrinsic(op, " | ", os, this); } else if (op->op.same_as(builtin::bitwise_not())) { - ICHECK_EQ(op->args.size(), 1U); + TVM_FFI_ICHECK_EQ(op->args.size(), 1U); os << "(~"; this->PrintExpr(op->args[0], os); os << ')'; @@ -686,19 +688,20 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) os << result; } else if (op->op.same_as(builtin::address_of())) { const BufferLoadNode* load = op->args[0].as(); - ICHECK(op->args.size() == 1 && load); - ICHECK_EQ(load->indices.size(), 1) << "CodeGenC only supports flat memory allocations."; + TVM_FFI_ICHECK(op->args.size() == 1 && load); + TVM_FFI_ICHECK_EQ(load->indices.size(), 1) + << "CodeGenC only supports flat memory allocations."; os << "(&(" << GetBufferRef(load->dtype, load->buffer.get(), load->indices[0]) << "))"; } else if (op->op.same_as(builtin::tvm_struct_get())) { - ICHECK_EQ(op->args.size(), 3U); + TVM_FFI_ICHECK_EQ(op->args.size(), 3U); os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as()->value); } else if (op->op.same_as(builtin::isnullptr())) { - ICHECK_EQ(op->args.size(), 1U); + TVM_FFI_ICHECK_EQ(op->args.size(), 1U); os << "("; this->PrintExpr(op->args[0], os); os << " == NULL)"; } else if (op->op.same_as(builtin::handle_add_byte_offset())) { - ICHECK_EQ(op->args.size(), 2U); + TVM_FFI_ICHECK_EQ(op->args.size(), 2U); os << "((void*)((char*)"; this->PrintExpr(op->args[0], os); os << " + "; @@ -707,8 +710,8 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) } else if (op->op.same_as(builtin::reinterpret())) { auto target_dtype = op->dtype; auto source_dtype = op->args[0]->dtype; - CHECK_EQ(target_dtype.lanes() * target_dtype.bits(), - source_dtype.lanes() * source_dtype.bits()) + TVM_FFI_ICHECK_EQ(target_dtype.lanes() * target_dtype.bits(), + source_dtype.lanes() * source_dtype.bits()) << "reinterpret expects source and target to have the same number of bits"; int ssa_scope = BeginScope(); std::string rhs = SSAGetID(PrintExpr(op->args[0]), source_dtype); @@ -723,24 +726,25 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) this->PrintExpr(op->args[0], os); os << ")"; } else if (op->op.same_as(builtin::lookup_param())) { - ICHECK_EQ(op->args.size(), 1); + TVM_FFI_ICHECK_EQ(op->args.size(), 1); const StringImmNode* str = op->args[0].as(); - ICHECK(str != nullptr); + TVM_FFI_ICHECK(str != nullptr); os << "__tvm_param__" << str->value; } else if (op->op.same_as(builtin::tvm_thread_invariant())) { os << "("; this->PrintExpr(op->args[0], os); os << ")"; } else { - LOG(FATAL) << "Unresolved call " << op->op; + TVM_FFI_THROW(InternalError) << "Unresolved call " << op->op; } } else if (auto opt = op->op.as()) { auto gvar = opt.value(); auto callee_name = GetFunctionName(gvar); PrintCallExtern(GetType(ffi::GetRef(op)), callee_name, op->args, false, os); } else { - LOG(FATAL) << "CodeGenC: Unknown operation " << op->op << " is neither a recognized built-in, " - << "nor a GlobalVar reference to another function in the IRModule"; + TVM_FFI_THROW(InternalError) << "CodeGenC: Unknown operation " << op->op + << " is neither a recognized built-in, " + << "nor a GlobalVar reference to another function in the IRModule"; } } @@ -764,8 +768,8 @@ void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, void CodeGenC::VisitStmt_(const DeclBufferNode* op) { this->PrintStmt(op->body); } void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLINT(*) - ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; - ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported."; + TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; + TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported."; DataType value_dtype = op->dtype; PrimExpr index = op->indices[0]; @@ -782,7 +786,7 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI arith::PVar base; if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) { const RampNode* ramp = index.as(); - ICHECK(ramp); + TVM_FFI_ICHECK(ramp); arith::ModularSet me = arith::Analyzer().modular_set(ramp->base); // The condition: {k * coeff + base} divisible by the alignment for any k if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes() == 0) { @@ -829,8 +833,8 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI } void CodeGenC::VisitStmt_(const BufferStoreNode* op) { - ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; - ICHECK(!op->predicate.defined()) << "Predicated buffer store is not supported."; + TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; + TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer store is not supported."; DataType value_dtype = op->value.dtype(); DataType element_dtype = op->buffer->dtype; @@ -888,14 +892,14 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) { void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) auto it = let_binding_.find(op->var); if (it != let_binding_.end()) { - ICHECK(deep_equal_(it->second->value, op->value)) + TVM_FFI_ICHECK(deep_equal_(it->second->value, op->value)) << "Let cannot bind the same var to two different values"; } else { let_binding_[op->var] = op; } std::string value = PrintExpr(op->value); if (print_ssa_form_) { - ICHECK(!var_idmap_.count(op->var.get())); + TVM_FFI_ICHECK(!var_idmap_.count(op->var.get())); var_idmap_[op->var.get()] = value; } else { PrintIndent(); @@ -914,7 +918,7 @@ void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) // We do this because it is hard to completely avoid a same LetNode appearing // at different places. bool removed = var_idmap_.erase(op->var.get()); - ICHECK(removed); + TVM_FFI_ICHECK(removed); } void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) @@ -976,13 +980,13 @@ void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) { // NOLINT( } if (op->indices.size() == 1) { // This is an extract element - CHECK(op->indices[0]->IsInstance()) + TVM_FFI_ICHECK(op->indices[0]->IsInstance()) << "The ShuffleNode indices are expected to be constants at codegen time. However, " << "a non-constant index is " << op->indices[0] << ". Please avoid using ShuffleNode or eliminate the ShuffleNode with loop unroll or " << "vectorize."; int64_t idx = Downcast(op->indices[0])->value; - ICHECK_LT(idx, concat_vec.size()); + TVM_FFI_ICHECK_LT(idx, concat_vec.size()); os << concat_vec[idx]; } else { // Print the shuffle as vector constructor @@ -991,7 +995,7 @@ void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) { // NOLINT( os << '('; for (size_t i = 0; i < op->indices.size(); ++i) { if (i != 0) os << ", "; - CHECK(op->indices[i]->IsInstance()) + TVM_FFI_ICHECK(op->indices[i]->IsInstance()) << "The ShuffleNode indices are expected to be constants at codegen time. However, " << "a non-constant index is " << op->indices[i] << ". Please avoid using ShuffleNode or eliminate the ShuffleNode with loop unroll or " @@ -1003,7 +1007,7 @@ void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) { // NOLINT( } void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) - LOG(FATAL) << "Broadcast: not supported "; + TVM_FFI_THROW(InternalError) << "Broadcast: not supported "; } void CodeGenC::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*) @@ -1019,7 +1023,7 @@ void CodeGenC::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(* void CodeGenC::VisitStmt_(const LetStmtNode* op) { std::string value = PrintExpr(op->value); if (print_ssa_form_) { - ICHECK(!var_idmap_.count(op->var.get())); + TVM_FFI_ICHECK(!var_idmap_.count(op->var.get())); var_idmap_[op->var.get()] = value; } else { PrintIndent(); @@ -1037,12 +1041,12 @@ void CodeGenC::VisitStmt_(const LetStmtNode* op) { } void CodeGenC::VisitStmt_(const AllocateNode* op) { - ICHECK(!is_zero(op->condition)); + TVM_FFI_ICHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); this->PrintIndent(); size_t constant_size = op->ConstantAllocationSize(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; + TVM_FFI_ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; auto scope = GetPtrStorageScope(op->buffer_var); alloc_storage_scope_[op->buffer_var.get()] = scope; @@ -1065,11 +1069,11 @@ void CodeGenC::VisitStmt_(const AttrStmtNode* op) { } } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); - ICHECK(v); + TVM_FFI_ICHECK(v); volatile_buf_.insert(v); } else if (op->attr_key == tir::attr::pragma_import_c) { const StringImmNode* value = op->value.as(); - ICHECK(value != nullptr); + TVM_FFI_ICHECK(value != nullptr); decl_stream << value->value; } this->PrintStmt(op->body); @@ -1080,7 +1084,7 @@ void CodeGenC::VisitStmt_(const AssertStmtNode* op) { PrintIndent(); if (const auto* str = op->message.as()) { // GLOG style check - stream << "ICHECK(" << cond << ") << \"" << str->value << "\";\n"; + stream << "TVM_FFI_ICHECK(" << cond << ") << \"" << str->value << "\";\n"; } else { stream << "assert(" << cond << ");\n"; } @@ -1160,7 +1164,7 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) { this->PrintStorageSync(call); return; } else if (call->op.same_as(builtin::tvm_struct_set())) { - ICHECK_EQ(call->args.size(), 4); + TVM_FFI_ICHECK_EQ(call->args.size(), 4); int kind = call->args[2].as()->value; DataType store_dtype = call->args[3].dtype(); std::string ref = GetStructRef(store_dtype, call->args[0], call->args[1], kind); @@ -1196,7 +1200,7 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) { } void CodeGenC::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) { - ICHECK_GT(t.lanes(), 1); + TVM_FFI_ICHECK_GT(t.lanes(), 1); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (i != 0) { os << "|"; diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 12a8d66bba9b..cb6ba238efee 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -75,7 +75,7 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, emit_fwd_func_decl_ = emit_fwd_func_decl; CodeGenC::AddFunction(gvar, func); if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc) && !has_tvm_ffi_main_func_) { - ICHECK(global_symbol.has_value()) + TVM_FFI_ICHECK(global_symbol.has_value()) << "CodeGenCHost: The entry func must have the global_symbol attribute, " << "but function " << gvar << " only has attributes " << func->attrs; @@ -123,7 +123,7 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { // NOLINT(*) void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - ICHECK_EQ(lanes, 1) << "does not support vector types"; + TVM_FFI_ICHECK_EQ(lanes, 1) << "does not support vector types"; os << "void*"; return; } @@ -186,7 +186,7 @@ void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) return; } } - LOG(FATAL) << "Cannot convert type " << t << " to C type"; + TVM_FFI_THROW(InternalError) << "Cannot convert type " << t << " to C type"; } void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) @@ -223,12 +223,12 @@ void CodeGenCHost::PrintGetFuncFromBackend(const std::string& func_name, void CodeGenCHost::PrintCallPacked(const CallNode* op) { const StringImmNode* func_name = op->args[0].as(); - ICHECK(func_name != nullptr) + TVM_FFI_ICHECK(func_name != nullptr) << "tvm_call_[c]packed_lowered expects first argument as function name"; int64_t begin = op->args[2].as()->value; int64_t end = op->args[3].as()->value; int64_t num_args = end - begin; - ICHECK_GE(num_args, 0); + TVM_FFI_ICHECK_GE(num_args, 0); std::string packed_func_name; if (op->op.same_as(builtin::tvm_call_packed_lowered())) { @@ -236,7 +236,7 @@ void CodeGenCHost::PrintCallPacked(const CallNode* op) { this->PrintGetFuncFromBackend(func_name->value, packed_func_name); } else { // directly use the original symbol - ICHECK(op->op.same_as(builtin::tvm_call_cpacked_lowered())); + TVM_FFI_ICHECK(op->op.same_as(builtin::tvm_call_cpacked_lowered())); packed_func_name = ffi::symbol::tvm_ffi_symbol_prefix + func_name->value; } @@ -269,7 +269,7 @@ void CodeGenCHost::PrintCallPacked(const CallNode* op) { std::string CodeGenCHost::GetPackedName(const CallNode* op) { const StringImmNode* s = op->args[0].as(); - ICHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name"; + TVM_FFI_ICHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name"; std::string func_name = s->value; std::string packed_func_name = func_name + "_packed"; std::string unique_name; @@ -289,7 +289,7 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT std::string stack_name = name_supply_->FreshName("stack"); const std::string& type = op->args[0].as()->value; const IntImmNode* num = op->args[1].as(); - ICHECK(num != nullptr); + TVM_FFI_ICHECK(num != nullptr); static_assert(alignof(TVMFFIAny) % alignof(DLTensor) == 0, "invariant"); size_t unit = sizeof(TVMFFIAny); size_t size = 0; @@ -300,7 +300,7 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT } else if (type == "array") { size = (num->value * sizeof(DLTensor) + unit - 1) / unit; } else { - LOG(FATAL) << "Unknown stack alloca type " << type; + TVM_FFI_THROW(InternalError) << "Unknown stack alloca type " << type; } this->PrintIndent(); this->stream << "TVMFFIAny " << stack_name << "[" << size << "];\n"; @@ -381,7 +381,7 @@ ffi::Module BuildCHost(IRModule mod, Target target) { std::vector> funcs; for (auto [gvar, base_func] : mod->functions) { - ICHECK(base_func->IsInstance()) << "CodegenCHost: Can only take PrimFunc"; + TVM_FFI_ICHECK(base_func->IsInstance()) << "CodegenCHost: Can only take PrimFunc"; auto prim_func = Downcast(base_func); funcs.push_back({gvar, prim_func}); } diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 32f0907ee2a8..3d5beacc63a9 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -56,7 +56,8 @@ std::string GetFP8Type(DataType type) { } else if (lanes == 16) { vec = "x16"; } else { - LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) for FP8"; + TVM_FFI_THROW(InternalError) + << "Only support scalar and vector types of width (2, 4, 8, 16) for FP8"; } stream << "__nv_fp8"; std::string suffix; @@ -67,7 +68,7 @@ std::string GetFP8Type(DataType type) { } else if (type.code() == DataType::kFloat8_e8m0fnu) { suffix = "_e8m0"; } else { - LOG(FATAL) << "Unsupported FP8 type in CUDA codegen"; + TVM_FFI_THROW(InternalError) << "Unsupported FP8 type in CUDA codegen"; } stream << vec << suffix; return stream.str(); @@ -88,7 +89,7 @@ std::string GetFP6Type(DataType type) { } else if (lanes == 16) { vec = "x16"; } else { - LOG(FATAL) << "Only support scalar and vector types of width (2, 4) for FP6"; + TVM_FFI_THROW(InternalError) << "Only support scalar and vector types of width (2, 4) for FP6"; } stream << "__nv_fp6"; std::string suffix; @@ -97,7 +98,7 @@ std::string GetFP6Type(DataType type) { } else if (type.code() == DataType::kFloat6_e3m2fn) { suffix = "_e3m2"; } else { - LOG(FATAL) << "Unsupported FP6 type in CUDA codegen"; + TVM_FFI_THROW(InternalError) << "Unsupported FP6 type in CUDA codegen"; } stream << vec << suffix; return stream.str(); @@ -118,14 +119,14 @@ std::string GetFP4Type(DataType type) { } else if (lanes == 16) { vec = "x16"; } else { - LOG(FATAL) << "Only support scalar and vector types of width (2, 4) for FP4"; + TVM_FFI_THROW(InternalError) << "Only support scalar and vector types of width (2, 4) for FP4"; } stream << "__nv_fp4"; std::string suffix; if (type.code() == DataType::kFloat4_e2m1fn) { suffix = "_e2m1"; } else { - LOG(FATAL) << "Unsupported FP4 type in CUDA codegen"; + TVM_FFI_THROW(InternalError) << "Unsupported FP4 type in CUDA codegen"; } stream << vec << suffix; return stream.str(); @@ -137,7 +138,7 @@ void CodeGenCUDA::Init(bool output_ssa) { CodeGenC::Init(output_ssa); vid_global_barrier_state_ = name_supply_->FreshName(runtime::symbol::tvm_global_barrier_state); vid_global_barrier_expect_ = name_supply_->FreshName("__barrier_expect"); - ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state); + TVM_FFI_ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state); } void CodeGenCUDA::PrintFunctionSignature(const ffi::String& function_name, const PrimFunc& func, @@ -149,7 +150,8 @@ void CodeGenCUDA::PrintFunctionSignature(const ffi::String& function_name, const } else if (calling_conv == CallingConv::kDefault) { os << "extern \"C\" __device__ "; } else { - LOG(FATAL) << "Unsupported calling convention for CUDA codegen: " << calling_conv; + TVM_FFI_THROW(InternalError) << "Unsupported calling convention for CUDA codegen: " + << calling_conv; } CodeGenC::PrintFunctionSignature(function_name, func, os); } @@ -333,14 +335,14 @@ void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) { } void CodeGenCUDA::BindThreadIndex(const IterVar& iv) { - ICHECK(!var_idmap_.count(iv->var.get())); + TVM_FFI_ICHECK(!var_idmap_.count(iv->var.get())); var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); } void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - ICHECK(t.is_scalar()) << "do not yet support vector types"; + TVM_FFI_ICHECK(t.is_scalar()) << "do not yet support vector types"; os << "void*"; return; } @@ -358,7 +360,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) if (t.is_scalar()) { os << "half"; } else if (lanes <= 8) { - ICHECK_EQ(lanes % 2, 0) << "Only support an even number of lanes for half type"; + TVM_FFI_ICHECK_EQ(lanes % 2, 0) << "Only support an even number of lanes for half type"; if (lanes <= 4) { os << "half" << lanes; } else { @@ -379,7 +381,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) // f8.v1 is emitted as *(float2*)(&(ul4.x)).x // f8.v2 is emitted as *(float2*)(&(ul4.x)).y // - ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4"; + TVM_FFI_ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4"; os << "ulonglong" << lanes / 2; } else { fail = true; @@ -403,7 +405,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) if (t.is_scalar()) { os << "nv_bfloat16"; } else if (lanes <= 8) { - ICHECK_EQ(lanes % 2, 0) << "only support even lane for bfloat16 type"; + TVM_FFI_ICHECK_EQ(lanes % 2, 0) << "only support even lane for bfloat16 type"; if (lanes <= 4) { os << "nv_bfloat16" << lanes; } else { @@ -467,7 +469,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "int"; return; } else { - LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; + TVM_FFI_THROW(InternalError) << "Cannot convert type " << t << " to CUDA type!"; } } case 4: { @@ -491,7 +493,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "int8"; return; } else { - LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; + TVM_FFI_THROW(InternalError) << "Cannot convert type " << t << " to CUDA type!"; } } case 8: { @@ -535,7 +537,8 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) // s4.z is emitted as *(short2*)(&(i2.y)).x // s4.w is emitted as *(short2*)(&(i2.y)).y // - ICHECK_EQ(t.lanes() % 2, 0) << "only support even lane for shorT type with lanes > 4"; + TVM_FFI_ICHECK_EQ(t.lanes() % 2, 0) + << "only support even lane for shorT type with lanes > 4"; os << "int" << t.lanes() / 2; } else { fail = true; @@ -558,7 +561,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) // i8.v1 is emitted as *(int2*)(&(l4.x)).x // i8.v2 is emitted as *(int2*)(&(l4.x)).y // - ICHECK_EQ(lanes % 2, 0) << "only support even lane for int32 type with lanes > 4"; + TVM_FFI_ICHECK_EQ(lanes % 2, 0) << "only support even lane for int32 type with lanes > 4"; os << "longlong" << lanes / 2; } else { fail = true; @@ -592,7 +595,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) return; } } - LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; + TVM_FFI_THROW(InternalError) << "Cannot convert type " << t << " to CUDA type"; } void CodeGenCUDA::PrintVecConstructor(DataType t, std::ostream& os) { @@ -643,7 +646,7 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, } static const char access[] = {'x', 'y', 'z', 'w'}; - ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); + TVM_FFI_ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { std::string type_name = t.is_int() ? "signed char" : "unsigned char"; if (t.lanes() == 2 || t.lanes() == 3) { @@ -681,7 +684,7 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, type_name = "float"; } } - ICHECK(!type_name.empty()); + TVM_FFI_ICHECK(!type_name.empty()); os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } else if (t.is_float4_e2m1fn()) { os << "([](__nv_fp4_storage_t v) { __nv_fp4_e2m1 t; t.__x = v; return t; })((" << vec @@ -695,7 +698,7 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) { this->PrintIndent(); static const char access[] = {'x', 'y', 'z', 'w'}; - ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); + TVM_FFI_ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (t.lanes() == 2 || t.lanes() == 3) { stream << vec << '.' << access[i % t.lanes()] << "=" @@ -738,7 +741,7 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i, type_name = "float"; } } - ICHECK(!type_name.empty()); + TVM_FFI_ICHECK(!type_name.empty()); stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " << value << ";\n"; } else { @@ -787,8 +790,9 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) { } void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) - ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. You must pass " - "all global arrays as input instead"; + TVM_FFI_ICHECK_NE(scope, "global") + << "Cannot allocate global memory when targeting CUDA. You must pass " + "all global arrays as input instead"; if (scope == "shared") { os << "__shared__ "; } else if (scope == "shared.dyn") { @@ -816,7 +820,7 @@ std::string CodeGenCUDA::CastFromTo(std::string value, DataType from, DataType t void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { DataType from_ty = op->value.dtype(); DataType target_ty = op->dtype; - ICHECK_EQ(target_ty.lanes(), from_ty.lanes()); + TVM_FFI_ICHECK_EQ(target_ty.lanes(), from_ty.lanes()); // Emit simple C-style type conversion. if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); @@ -935,7 +939,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { if (op->op.same_as(builtin::tvm_fill_fragment())) { need_mma_h_ = true; - ICHECK_EQ(op->args.size(), 6U); + TVM_FFI_ICHECK_EQ(op->args.size(), 6U); os << "nvcuda::wmma::fill_fragment("; this->PrintExpr(op->args[0], os); os << "["; @@ -945,7 +949,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << ")"; } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) { need_mma_h_ = true; - ICHECK_EQ(op->args.size(), 8U); + TVM_FFI_ICHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::load_matrix_sync("; this->PrintExpr(op->args[0], os); os << "["; @@ -957,7 +961,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << ")"; } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) { need_mma_h_ = true; - ICHECK_EQ(op->args.size(), 8U); + TVM_FFI_ICHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::store_matrix_sync("; this->PrintExpr(op->args[5], os); os << ", "; @@ -969,12 +973,12 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { if (const StringImmNode* str = op->args[7].as()) { os << ", nvcuda::wmma::mem_" << str->value; } else { - LOG(FATAL) << "Invalid parameters"; + TVM_FFI_THROW(InternalError) << "Invalid parameters"; } os << ")"; } else if (op->op.same_as(builtin::tvm_mma_sync())) { need_mma_h_ = true; - ICHECK_EQ(op->args.size(), 8U); + TVM_FFI_ICHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::mma_sync("; for (int i = 0; i < 4; ++i) { this->PrintExpr(op->args[i * 2], os); @@ -984,7 +988,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { } } else if (op->op.same_as(builtin::tvm_bmma_sync())) { need_mma_h_ = true; - ICHECK_EQ(op->args.size(), 8U); + TVM_FFI_ICHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::bmma_sync("; for (int i = 0; i < 4; ++i) { this->PrintExpr(op->args[i * 2], os); @@ -1007,7 +1011,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { // arg 11: C accumulator index // arg 12: saturate // arg 13: (optional) 1-bit operator (xor or and) - ICHECK(op->args.size() == 13U || op->args.size() == 14U); + TVM_FFI_ICHECK(op->args.size() == 13U || op->args.size() == 14U); std::string shape = Downcast(op->args[0])->value; std::string A_layout = Downcast(op->args[1])->value; std::string B_layout = Downcast(op->args[2])->value; @@ -1044,7 +1048,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { // arg 13: metadata index // arg 14: sparse_selector // arg 15: saturate - ICHECK_EQ(op->args.size(), 16U); + TVM_FFI_ICHECK_EQ(op->args.size(), 16U); std::string shape = Downcast(op->args[0])->value; std::string A_layout = Downcast(op->args[1])->value; std::string B_layout = Downcast(op->args[2])->value; @@ -1073,7 +1077,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { // arg 4: The offset of the element to store in the local buffer. // arg 5: pointer to the shared memory buffer to load. // arg 6: The offset of the start element of the row to load in shared memory. - ICHECK_EQ(op->args.size(), 7U); + TVM_FFI_ICHECK_EQ(op->args.size(), 7U); bool trans = Downcast(op->args[0])->value; int num = Downcast(op->args[1])->value; std::string type = Downcast(op->args[2])->value; @@ -1084,7 +1088,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { // Since ldmatrix assumes that a matrix element is 16 bit, it cannot properly transpose an // int8 matrix. std::string smem_stride = this->PrintExpr(op->args[6]); - ICHECK(num == 4); + TVM_FFI_ICHECK(num == 4); os << "for (int i = 0; i < 16; ++i) {\n"; os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr << "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride + @@ -1104,7 +1108,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string src_offset = this->PrintExpr(op->args[4]); PrimExpr stride = op->args[5]; - ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for now"; + TVM_FFI_ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for now"; // Each thread in a warp holds a certain number of elements of an MMA output. // For example, if we compute a 16x16 tile using MMA, each thread holds 8 elements @@ -1116,7 +1120,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { const auto index_map_func = tvm::ffi::Function::GetGlobal("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout"); - ICHECK(index_map_func.has_value()); + TVM_FFI_ICHECK(index_map_func.has_value()); arith::Analyzer analyzer; auto inverse_index_map = @@ -1176,7 +1180,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string src_offset = this->PrintExpr(op->args[3]); std::string size = this->PrintExpr(op->args[4]); int barrier_id = Downcast(op->args[5])->value; - CHECK(barrier_id < barrier_count_); + TVM_FFI_ICHECK(barrier_id < barrier_count_); std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size, barrier); } else if (op->op.same_as(builtin::ptx_commit_group())) { @@ -1187,40 +1191,40 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) { need_cast_smem_ptr_to_int_ = true; int barrier_id = Downcast(op->args[0])->value; - CHECK(barrier_id < barrier_count_); + TVM_FFI_ICHECK(barrier_id < barrier_count_); std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; this->stream << PrintCpAsyncBarrierAsm(barrier); } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) { need_cast_smem_ptr_to_int_ = true; int barrier_id = Downcast(op->args[0])->value; - CHECK(barrier_id < barrier_count_); + TVM_FFI_ICHECK(barrier_id < barrier_count_); std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; std::string thread_count = this->PrintExpr(op->args[1]); this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count); } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { need_cast_smem_ptr_to_int_ = true; int barrier_id = Downcast(op->args[0])->value; - CHECK(barrier_id < barrier_count_); + TVM_FFI_ICHECK(barrier_id < barrier_count_); std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; this->stream << PrintArriveBarrierAsm(barrier); } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) { need_cast_smem_ptr_to_int_ = true; int barrier_id = Downcast(op->args[0])->value; - CHECK(barrier_id < barrier_count_); + TVM_FFI_ICHECK(barrier_id < barrier_count_); std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; std::string byte_count = this->PrintExpr(op->args[1]); this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count); } else if (op->op.same_as(builtin::ptx_wait_barrier())) { need_cast_smem_ptr_to_int_ = true; int barrier_id = Downcast(op->args[0])->value; - CHECK(barrier_id < barrier_count_); + TVM_FFI_ICHECK(barrier_id < barrier_count_); std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; this->stream << PrintWaitBarrierAsm(barrier); } else if (op->op.same_as(builtin::create_barriers())) { - CHECK_EQ(barrier_count_, -1); + TVM_FFI_ICHECK_EQ(barrier_count_, -1); int barrier_count = Downcast(op->args[0])->value; // pad barrier alignment to avoid runtime alignment errors - CHECK_EQ(barrier_alignment_bytes_ % sizeof(uint64_t), 0); + TVM_FFI_ICHECK_EQ(barrier_alignment_bytes_ % sizeof(uint64_t), 0); int barrier_alignment_count = barrier_alignment_bytes_ / sizeof(uint64_t); if (barrier_count % barrier_alignment_count != 0) { barrier_count = ((barrier_count / barrier_alignment_count) + 1) * barrier_alignment_count; @@ -1274,10 +1278,10 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { tgt_dtype.lanes() * tgt_dtype.bits() == src_dtype.lanes() * src_dtype.bits()) { return CodeGenC::VisitExpr_(op, os); } - CHECK_EQ(tgt_dtype.lanes(), src_dtype.lanes()) + TVM_FFI_ICHECK_EQ(tgt_dtype.lanes(), src_dtype.lanes()) << "E2M1 float4 reinterpret expects source and target to have the same number of lanes. " << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype; - CHECK_EQ(tgt_dtype.bytes(), src_dtype.bytes()) + TVM_FFI_ICHECK_EQ(tgt_dtype.bytes(), src_dtype.bytes()) << "E2M1 float4 reinterpret expects source and target to have the same number of bytes. " << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype; @@ -1334,7 +1338,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { } os << PrintExpr(tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value})); } else { - LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: " << lanes; + TVM_FFI_THROW(InternalError) + << "Invalid number of lanes for float4_e2m1fn reinterpret: " << lanes; } EndScope(ssa_scope); } else if (op->op.same_as(builtin::thread_return())) { @@ -1355,7 +1360,8 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { fragment_layouts[buffer] = layout_str->value; } else if (op->attr_key == tir::attr::async_commit_queue_scope) { const IntImmNode* queue_id = op->value.as(); - ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; + TVM_FFI_ICHECK(queue_id && queue_id->value == 0) + << "For CUDA, the index of an async queue must be 0."; this->VisitStmt(op->body); auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {}); this->VisitExpr(commit_group, this->stream); @@ -1363,12 +1369,13 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { } else if (op->attr_key == tir::attr::async_wait_queue_scope) { auto wait_attrs = GetAsyncWaitAttributes(op); auto queue_id = wait_attrs.first.as(); - ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; + TVM_FFI_ICHECK(queue_id && queue_id->value == 0) + << "For CUDA, the index of an async queue must be 0."; auto wait_cnt = wait_attrs.second; auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt}); this->VisitExpr(wait_group, this->stream); auto inner = op->body.as(); - ICHECK(inner); + TVM_FFI_ICHECK(inner); this->VisitStmt(inner->body); return; } @@ -1376,7 +1383,7 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { } void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { - ICHECK(!is_zero(op->condition)); + TVM_FFI_ICHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); this->PrintIndent(); @@ -1384,15 +1391,15 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { const VarNode* buffer = op->buffer_var.as(); if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { - ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || - op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) || - op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) || - op->dtype == DataType::BFloat(16)) + TVM_FFI_ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || + op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) || + op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) || + op->dtype == DataType::BFloat(16)) << "Matrix_a and matrix_b only support half or char or unsigned char " << "or uint4 or int4 or int1 type for now"; } else { - ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) || - op->dtype == DataType::Int(32)) + TVM_FFI_ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) || + op->dtype == DataType::Int(32)) << "Accumulator only support half, float and int type for now"; } PrintWmmaScope(scope, op->dtype, buffer, stream); @@ -1405,7 +1412,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { stream << ' ' << vid << "[];\n"; } else { size_t constant_size = op->ConstantAllocationSize(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; + TVM_FFI_ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; if (scope.find("wmma.") == 0) { constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); @@ -1441,7 +1448,7 @@ void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) { void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { int lanes = op->dtype.lanes(); - CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed."; + TVM_FFI_CHECK_LE(lanes, 4, ValueError) << "Ramp of more than 4 lanes is not allowed."; PrintVecConstructor(op->dtype, os); os << "("; for (int i = 0; i < lanes; i++) { @@ -1457,7 +1464,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && lanes == 4) { // make_int8x4 const int64_t* p = as_const_int(op->value); - ICHECK(p); + TVM_FFI_ICHECK(p); int64_t v = *p & 0xFF; v = (v << 24) | (v << 16) | (v << 8) | v; if (op->dtype.is_uint()) { @@ -1508,7 +1515,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO if (op->dtype.is_float8() || op->dtype.is_float4()) { int lanes = op->dtype.lanes(); - ICHECK(lanes == 1 || lanes == 2 || lanes == 4); + TVM_FFI_ICHECK(lanes == 1 || lanes == 2 || lanes == 4); std::string v = PrintExpr(op->value); // Implicit conversion from float back to fp8 PrintType(op->dtype, os); @@ -1524,7 +1531,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) { bool fail = false; const int64_t* p = as_const_int(op->value); - ICHECK(p); + TVM_FFI_ICHECK(p); int64_t v = *p & 0xF; if (lanes == 4) { @@ -1582,8 +1589,8 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) { } // Codegen vector condition case by serializing the select op. - ICHECK(op->false_value->dtype == op->dtype && op->true_value->dtype == op->dtype && - op->dtype.lanes() == op->condition.dtype().lanes()); + TVM_FFI_ICHECK(op->false_value->dtype == op->dtype && op->true_value->dtype == op->dtype && + op->dtype.lanes() == op->condition.dtype().lanes()); std::string r_var = name_supply_->FreshName("_"); this->PrintIndent(); @@ -1677,7 +1684,7 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) break; } default: - LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; + TVM_FFI_THROW(InternalError) << "Bad bit-width for float: " << op->dtype << "\n"; } } @@ -1689,7 +1696,7 @@ void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const Var std::ostream& os) { std::stringstream type; PrintType(t, type); - ICHECK(fragment_shapes.count(variable)) + TVM_FFI_ICHECK(fragment_shapes.count(variable)) << "Cannot find shape of the wmma fragment " << variable->name_hint; std::string shape_str = fragment_shapes.at(variable); if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) { @@ -1700,26 +1707,26 @@ void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const Var } else if (t.bits() == 1) { type << "nvcuda::wmma::experimental::precision::b1"; } else { - LOG(FATAL) << "Unhandled interger type for wmma fragment!"; + TVM_FFI_THROW(InternalError) << "Unhandled interger type for wmma fragment!"; } } else if (t.is_uint()) { if (t.bits() == 4) { type << "nvcuda::wmma::experimental::precision::u4"; } else { - LOG(FATAL) << "Unhandled interger type for wmma fragment!"; + TVM_FFI_THROW(InternalError) << "Unhandled interger type for wmma fragment!"; } } } if (scope == "wmma.matrix_a") { need_mma_h_ = true; std::string layout_str = fragment_layouts[variable]; - ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a"; + TVM_FFI_ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a"; os << "nvcuda::wmma::fragment"; } else if (scope == "wmma.matrix_b") { need_mma_h_ = true; std::string layout_str = fragment_layouts[variable]; - ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b"; + TVM_FFI_ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b"; os << "nvcuda::wmma::fragment"; } else if (scope == "wmma.accumulator") { @@ -1733,14 +1740,14 @@ int stoi(const std::string& str) { try { return std::stoi(str); } catch (std::invalid_argument& e) { - LOG(FATAL) << "Cannot convert \"" << str << "\" to int"; + TVM_FFI_THROW(InternalError) << "Cannot convert \"" << str << "\" to int"; throw; } } int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size) { - ICHECK(fragment_shapes.count(variable)) + TVM_FFI_ICHECK(fragment_shapes.count(variable)) << "Cannot find shape of the wmma fragment " << variable->name_hint; std::string shape_str = fragment_shapes.at(variable); std::pair dim = GetWmmaFragmentDimSize(shape_str, scope); @@ -1766,7 +1773,7 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const BufferLoad void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) { - ICHECK_GT(t.lanes(), 1); + TVM_FFI_ICHECK_GT(t.lanes(), 1); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (!(t.lanes() == 2 || t.lanes() == 3)) { if (i != 0) { diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 01042776c971..8dfdd977accb 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -78,7 +78,7 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { // add to alloc buffer type. auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.has_value()) + TVM_FFI_ICHECK(global_symbol.has_value()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; // Function header. @@ -122,7 +122,7 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { decl_stream << "struct " << arg_buf_type << " {\n"; for (size_t i = num_buffer; i < func->params.size(); ++i) { Var v = func->params[i]; - ICHECK(!v.dtype().is_handle()); + TVM_FFI_ICHECK(!v.dtype().is_handle()); std::string vid = AllocVarID(v.get()); std::ostringstream vref; if (v.dtype().bits() == 32) { @@ -146,8 +146,8 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { decl_stream << "};\n\n"; } // Setup the thread group info. - ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); - ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); + TVM_FFI_ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); + TVM_FFI_ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); int work_dim = 0; auto launch_params = func->GetAttr>(tir::attr::kKernelLaunchParams).value(); @@ -179,7 +179,7 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { } void CodeGenMetal::BindThreadIndex(const IterVar& iv) { - ICHECK(!var_idmap_.count(iv->var.get())); + TVM_FFI_ICHECK(!var_idmap_.count(iv->var.get())); // if we only have threadIdx.x // metal will directly print as threadIdx std::string vname = iv->thread_tag; @@ -193,7 +193,7 @@ void CodeGenMetal::BindThreadIndex(const IterVar& iv) { void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - ICHECK_EQ(lanes, 1) << "do not yet support vector types"; + TVM_FFI_ICHECK_EQ(lanes, 1) << "do not yet support vector types"; os << "void*"; return; } @@ -268,7 +268,7 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "bfloat"; return; } - LOG(FATAL) << "Cannot convert type " << t << " to Metal type"; + TVM_FFI_THROW(InternalError) << "Cannot convert type " << t << " to Metal type"; } void CodeGenMetal::PrintStorageSync(const CallNode* op) { @@ -280,7 +280,7 @@ void CodeGenMetal::PrintStorageSync(const CallNode* op) { this->PrintIndent(); this->stream << "threadgroup_barrier(mem_flags::mem_threadgroup);\n"; } else if (sync == "global") { - LOG(FATAL) << "global barrier not supported"; + TVM_FFI_THROW(InternalError) << "global barrier not supported"; } } @@ -304,25 +304,25 @@ void CodeGenMetal::PrintStorageScope(const std::string& scope, std::ostream& os) } else if (scope == "local") { os << "thread "; } else { - LOG(FATAL) << "Unknown storage scope `" << scope << "`"; + TVM_FFI_THROW(InternalError) << "Unknown storage scope `" << scope << "`"; } } void CodeGenMetal::VisitStmt_(const AllocateNode* op) { - ICHECK(!is_zero(op->condition)); + TVM_FFI_ICHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); this->PrintIndent(); size_t constant_size = op->ConstantAllocationSize(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; + TVM_FFI_ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; auto scope = GetPtrStorageScope(op->buffer_var); alloc_storage_scope_[op->buffer_var.get()] = scope; if (scope == "metal.simdgroup") { - ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) || - op->dtype == DataType::BFloat(16)) + TVM_FFI_ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) || + op->dtype == DataType::BFloat(16)) << "Only float16, float32, and bfloat16 are supported, but got " << op->dtype; - ICHECK(constant_size % 64 == 0) + TVM_FFI_ICHECK(constant_size % 64 == 0) << "Only 8x8 matrix is supported, but got " << constant_size << " bytes\n"; std::ostringstream dtype_os; @@ -358,23 +358,23 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N } void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - CHECK(!op->op.as()) + TVM_FFI_ICHECK(!op->op.as()) << "CodegenMetal does not support inter-function calls, " << "but expression " << ffi::GetRef(op) << " calls PrimFunc " << op->op; auto f_check_simdgroup_shape = [](PrimExpr col, PrimExpr row) { - ICHECK(col->IsInstance() && row->IsInstance()) + TVM_FFI_ICHECK(col->IsInstance() && row->IsInstance()) << "Only constant shape is supported for simdgroup matrix, but got " << col << "x" << row; int col_val = col.as()->value; int row_val = row.as()->value; - ICHECK(col_val == 8 && row_val == 8) + TVM_FFI_ICHECK(col_val == 8 && row_val == 8) << "Only 8x8 matrix is supported, but got " << col_val << "x" << row_val; }; if (op->op.same_as(builtin::make_filled_simdgroup_matrix())) { - ICHECK_EQ(op->args.size(), 5); + TVM_FFI_ICHECK_EQ(op->args.size(), 5); Var var = Downcast(op->args[0]); // Get the data type of the simdgroup matrix auto it = simdgroup_dtype_.find(var.get()); - ICHECK(it != simdgroup_dtype_.end()) + TVM_FFI_ICHECK(it != simdgroup_dtype_.end()) << "Cannot find variable allocation for simdgroup: " << var; const std::string& dtype_str = it->second; f_check_simdgroup_shape(op->args[3], op->args[4]); @@ -382,19 +382,19 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT << dtype_str << ", " << PrintExpr(op->args[3]) << ", " << PrintExpr(op->args[4]) << ">(" << PrintExpr(op->args[2]) << ")"; } else if (op->op.same_as(builtin::simdgroup_load())) { - ICHECK_EQ(op->args.size(), 7); + TVM_FFI_ICHECK_EQ(op->args.size(), 7); f_check_simdgroup_shape(op->args[4], op->args[5]); os << "simdgroup_load(" << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], " << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ", 0, " << PrintExpr(op->args[6]) << ")"; } else if (op->op.same_as(builtin::simdgroup_store())) { - ICHECK_EQ(op->args.size(), 7); + TVM_FFI_ICHECK_EQ(op->args.size(), 7); f_check_simdgroup_shape(op->args[4], op->args[5]); os << "simdgroup_store(" << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], " << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ", 0, " << PrintExpr(op->args[6]) << ")"; } else if (op->op.same_as(builtin::simdgroup_multiply_accumulate())) { - ICHECK_EQ(op->args.size(), 8); + TVM_FFI_ICHECK_EQ(op->args.size(), 8); os << "simdgroup_multiply_accumulate(" // << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], " // << PrintExpr(op->args[2]) << "[" << PrintExpr(op->args[3]) << "], " // @@ -442,9 +442,9 @@ ffi::Module BuildMetal(IRModule mod, Target target) { std::string fmt = fmetal_compile ? "metallib" : "metal"; for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; + TVM_FFI_ICHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; auto global_symbol = kv.second->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.has_value()); + TVM_FFI_ICHECK(global_symbol.has_value()); std::string func_name = global_symbol.value(); source_maker << "// Function: " << func_name << "\n"; @@ -452,7 +452,7 @@ ffi::Module BuildMetal(IRModule mod, Target target) { cg.Init(output_ssa); auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + TVM_FFI_ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(kv.first, f); diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index e213be519c16..4a7e4a667d13 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -181,7 +181,7 @@ std::string CodeGenOpenCL::Finish() { } void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) { - ICHECK(!var_idmap_.count(iv->var.get())); + TVM_FFI_ICHECK(!var_idmap_.count(iv->var.get())); runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); std::ostringstream os; if (ts.rank == 1) { @@ -195,7 +195,7 @@ void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) { void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - ICHECK_EQ(lanes, 1) << "do not yet support vector types"; + TVM_FFI_ICHECK_EQ(lanes, 1) << "do not yet support vector types"; os << "void*"; return; } @@ -266,7 +266,7 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) return; } } - LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type"; + TVM_FFI_THROW(InternalError) << "Cannot convert type " << t << " to OpenCL type"; } void CodeGenOpenCL::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) @@ -282,7 +282,7 @@ void CodeGenOpenCL::PrintType(const Type& type, std::ostream& os) { // NOLINT(* } else if (IsVoidType(type)) { os << "void"; } else { - LOG(FATAL) << "Type " << type << " does not have a corresponding C Type"; + TVM_FFI_THROW(InternalError) << "Type " << type << " does not have a corresponding C Type"; } } @@ -319,7 +319,7 @@ void CodeGenOpenCL::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr void CodeGenOpenCL::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) { // NOLINT(*) - ICHECK_GT(t.lanes(), 1); + TVM_FFI_ICHECK_GT(t.lanes(), 1); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (i != 0) { os << "|"; @@ -351,7 +351,7 @@ void CodeGenOpenCL::PrintStorageSync(const CallNode* op) { this->PrintIndent(); this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n"; } else if (sync == "global") { - LOG(FATAL) << "not supported"; + TVM_FFI_THROW(InternalError) << "not supported"; } } @@ -407,8 +407,9 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { if (op->op.same_as(builtin::address_of())) { // Overload tvm_address_of to add storage scope (e.g. __global). const BufferLoadNode* load = op->args[0].as(); - ICHECK(op->args.size() == 1 && load); - ICHECK_EQ(load->indices.size(), 1) << "CodeGenOpenCL only supports flat memory allocations."; + TVM_FFI_ICHECK(op->args.size() == 1 && load); + TVM_FFI_ICHECK_EQ(load->indices.size(), 1) + << "CodeGenOpenCL only supports flat memory allocations."; os << "(("; auto it = alloc_storage_scope_.find(load->buffer->data.get()); if (it != alloc_storage_scope_.end()) { @@ -420,11 +421,11 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { os << ')'; } else if (op->op.same_as(builtin::texture2d_store())) { auto* ptr_type = op->args[0].as()->type_annotation.as(); - ICHECK(ptr_type != nullptr) << "Texture Var's must be of PointerType"; - ICHECK(runtime::IsTextureStorage(std::string(ptr_type->storage_scope))) + TVM_FFI_ICHECK(ptr_type != nullptr) << "Texture Var's must be of PointerType"; + TVM_FFI_ICHECK(runtime::IsTextureStorage(std::string(ptr_type->storage_scope))) << "builtin::texture2d_store() only supports storing to texture buffers"; const int channel_size = Downcast(op->args[4])->value; - ICHECK(channel_size == 64 || channel_size == 128) + TVM_FFI_ICHECK(channel_size == 64 || channel_size == 128) << "Unsupported Channel Size: " << channel_size; DataType channel_type = runtime::GetChannelType(channel_size); @@ -438,7 +439,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { } else if (channel_size == 128) { os << "write_imagef("; } else { - LOG(FATAL) << "Unsupported Channel Size: " << channel_size; + TVM_FFI_THROW(InternalError) << "Unsupported Channel Size: " << channel_size; } this->PrintExpr(op->args[0], os); os << ", "; @@ -460,7 +461,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { std::stringstream ss; const int channel_size = Downcast(op->args[4])->value; const int data_lanes = channel_size / op->dtype.bits(); - ICHECK(channel_size == 64 || channel_size == 128) + TVM_FFI_ICHECK(channel_size == 64 || channel_size == 128) << "Unsupported Channel Size: " << channel_size; ss << "as_"; this->PrintType(op->dtype.with_lanes(data_lanes), ss); @@ -470,7 +471,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { } else if (channel_size == 128) { ss << "READ_IMAGEF("; } else { - LOG(FATAL) << "Unsupported Channel Size: " << channel_size; + TVM_FFI_THROW(InternalError) << "Unsupported Channel Size: " << channel_size; } this->PrintExpr(op->args[0], ss); ss << ", "; @@ -500,7 +501,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(ramp->base, os); os << "))"; } else { - LOG(FATAL) << "Unsupported Texture Load Args"; + TVM_FFI_THROW(InternalError) << "Unsupported Texture Load Args"; } } else { os << "(("; @@ -598,8 +599,8 @@ void CodeGenOpenCL::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT if (op->dtype.is_int() || op->dtype.is_uint()) { opstr = "%"; } else { - ICHECK(op->dtype.is_float()) << "Expected floating point or integer dtype in Mod, but got " - << op->dtype; + TVM_FFI_ICHECK(op->dtype.is_float()) + << "Expected floating point or integer dtype in Mod, but got " << op->dtype; opstr = "fmod"; } if (op->dtype.lanes() == 1) { @@ -685,10 +686,11 @@ ffi::Module BuildOpenCL(IRModule mod, Target target) { ffi::Map functions; for (auto [gvar, base_func] : mod->functions) { - ICHECK(base_func->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; + TVM_FFI_ICHECK(base_func->IsInstance()) + << "CodeGenOpenCL: Can only take PrimFunc"; auto prim_func = Downcast(base_func); auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + TVM_FFI_ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; functions.Set(gvar, prim_func); } diff --git a/src/target/source/codegen_params.cc b/src/target/source/codegen_params.cc index d840ebec7df3..ae915f278f57 100644 --- a/src/target/source/codegen_params.cc +++ b/src/target/source/codegen_params.cc @@ -163,8 +163,8 @@ void PrintFloatingPointArray(void* data, size_t num_elements, int indent_chars, void TensorDataToC(::tvm::runtime::Tensor arr, int indent_chars, std::ostream& os, const std::string& eol) { auto arr_type = arr.DataType(); - CHECK_EQ(arr_type.lanes(), 1) << "CodegenParams: only support generating 1-lane parameters; saw " - << arr_type.lanes(); + TVM_FFI_ICHECK_EQ(arr_type.lanes(), 1) + << "CodegenParams: only support generating 1-lane parameters; saw " << arr_type.lanes(); auto shape = arr.Shape(); int num_elements = 1; @@ -178,8 +178,8 @@ void TensorDataToC(::tvm::runtime::Tensor arr, int indent_chars, std::ostream& o os.fill('0'); switch (arr_type.code()) { case runtime::DataType::kInt: - CHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 || - arr_type.bits() == 64) + TVM_FFI_ICHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 || + arr_type.bits() == 64) << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw " << arr_type.bits() << "-bit array"; if (arr_type.bits() == 8) { @@ -191,13 +191,13 @@ void TensorDataToC(::tvm::runtime::Tensor arr, int indent_chars, std::ostream& o } else if (arr_type.bits() == 64) { PrintIntegralArray(arr->data, num_elements, indent_chars, os, eol); } else { - CHECK(false) << "should not get here"; + TVM_FFI_ICHECK(false) << "should not get here"; } break; case runtime::DataType::TypeCode::kUInt: - CHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 || - arr_type.bits() == 64) + TVM_FFI_ICHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 || + arr_type.bits() == 64) << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw " << arr_type.bits() << "-bit array"; @@ -210,7 +210,7 @@ void TensorDataToC(::tvm::runtime::Tensor arr, int indent_chars, std::ostream& o } else if (arr_type.bits() == 64) { PrintIntegralArray(arr->data, num_elements, indent_chars, os, eol); } else { - CHECK(false) << "should not get here"; + TVM_FFI_ICHECK(false) << "should not get here"; } break; @@ -225,15 +225,15 @@ void TensorDataToC(::tvm::runtime::Tensor arr, int indent_chars, std::ostream& o } else if (arr_type.bits() == 64) { PrintFloatingPointArray(arr->data, num_elements, indent_chars, os, eol); } else { - CHECK(false) << "CodegenParams: only support 32- or 64-bit floating point; saw " - << arr_type.bits() << "-bit array"; + TVM_FFI_ICHECK(false) << "CodegenParams: only support 32- or 64-bit floating point; saw " + << arr_type.bits() << "-bit array"; } break; } case runtime::DataType::TypeCode::kBFloat: { // NOTE: print types not widely supported by C as uint16_t. - CHECK(arr_type.bits() == 16) + TVM_FFI_ICHECK(arr_type.bits() == 16) << "CodegenParams: only support generating 16-bit bfloat params; saw " << arr_type.bits() << "-bit array"; PrintIntegralArray(arr->data, num_elements, indent_chars, os, eol); @@ -241,7 +241,7 @@ void TensorDataToC(::tvm::runtime::Tensor arr, int indent_chars, std::ostream& o } default: - CHECK(false) << "Data type '" << arr_type << "' not supported"; + TVM_FFI_ICHECK(false) << "Data type '" << arr_type << "' not supported"; } os.flags(old_fmtflags); diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc index 917036b8e2de..c0f906a34f62 100644 --- a/src/target/source/codegen_source_base.cc +++ b/src/target/source/codegen_source_base.cc @@ -53,7 +53,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) { } std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) { - ICHECK(!var_idmap_.count(v)) << "Need input to be in SSA form dup " << v->name_hint; + TVM_FFI_ICHECK(!var_idmap_.count(v)) << "Need input to be in SSA form dup " << v->name_hint; std::string key = v->name_hint; std::string vid = name_supply_->FreshName(key); std::replace(vid.begin(), vid.end(), ':', '_'); @@ -65,7 +65,7 @@ std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) { std::string CodeGenSourceBase::GetVarID(const tir::VarNode* v) const { auto it = var_idmap_.find(v); - ICHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint; + TVM_FFI_ICHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint; return it->second; } @@ -83,7 +83,7 @@ void CodeGenSourceBase::MarkConst(std::string vid) { e.scope_id = 0; ssa_assign_map_[vid] = e; } else { - ICHECK_EQ(it->second.vid, vid); + TVM_FFI_ICHECK_EQ(it->second.vid, vid); } } @@ -100,7 +100,7 @@ void CodeGenSourceBase::EndScope(int scope_id) { } void CodeGenSourceBase::PrintType(DataType type, std::ostream& os) { // NOLINT(*) - ICHECK_EQ(type.lanes(), 1) << "do not yet support vector types"; + TVM_FFI_ICHECK_EQ(type.lanes(), 1) << "do not yet support vector types"; if (type.is_handle()) { os << "void*"; return; @@ -147,7 +147,7 @@ void CodeGenSourceBase::PrintType(DataType type, std::ostream& os) { // NOLINT( } } } - LOG(FATAL) << "Cannot convert type " << type << " to C type"; + TVM_FFI_THROW(InternalError) << "Cannot convert type " << type << " to C type"; } void CodeGenSourceBase::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) @@ -159,7 +159,7 @@ void CodeGenSourceBase::PrintType(const Type& type, std::ostream& os) { // NOLI } else if (IsVoidType(type)) { os << "void"; } else { - LOG(FATAL) << "Type " << type << " does not have a corresponding C Type"; + TVM_FFI_THROW(InternalError) << "Type " << type << " does not have a corresponding C Type"; } } diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 5d306780922a..eb5351bd0f98 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -83,11 +83,11 @@ class WebGPUWorkgroupInfoCollector : public StmtExprVisitor { if (iv->thread_tag.length() != 0) { runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); if (ts.rank == 1) { - ICHECK_GE(ts.dim_index, 0) << "vthread should have been optimized out by here"; - ICHECK_LT(ts.dim_index, 3); + TVM_FFI_ICHECK_GE(ts.dim_index, 0) << "vthread should have been optimized out by here"; + TVM_FFI_ICHECK_LT(ts.dim_index, 3); auto* sizeptr = op->value.as(); - ICHECK(sizeptr) << "CodeGenWebGPU: only allows constant thread group size " - << " get " << op->value; + TVM_FFI_ICHECK(sizeptr) << "CodeGenWebGPU: only allows constant thread group size " + << " get " << op->value; info_.workgroup_size[ts.dim_index] = static_cast(sizeptr->value); } else if (ts.rank == 0) { if (ts.dim_index == 2) { @@ -134,13 +134,13 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re // skip the first underscore, so SSA variable starts from name_supply_->FreshName("v_"); // Setup the thread group info. - ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); - ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); - ICHECK_EQ(name_supply_->FreshName("gridDim"), "gridDim"); + TVM_FFI_ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); + TVM_FFI_ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); + TVM_FFI_ICHECK_EQ(name_supply_->FreshName("gridDim"), "gridDim"); // add to alloc buffer type. auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.has_value()) + TVM_FFI_ICHECK(global_symbol.has_value()) << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; header_stream << "//----------------------------------------\n" @@ -165,13 +165,15 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re if (t.is_handle()) { auto* ptr = arg->type_annotation.as(); - ICHECK(ptr) << "All handles passed to the CodeGenWebGPU must have a type_annotation as a " - "PointerType, " - << "and must point to a PrimType"; + TVM_FFI_ICHECK(ptr) + << "All handles passed to the CodeGenWebGPU must have a type_annotation as a " + "PointerType, " + << "and must point to a PrimType"; auto* prim = ptr->element_type.as(); - ICHECK(prim) << "All handles passed to the CodeGenWebGPU must have a type_annotation as a " - "PointerType, " - << "and must point to a PrimType"; + TVM_FFI_ICHECK(prim) + << "All handles passed to the CodeGenWebGPU must have a type_annotation as a " + "PointerType, " + << "and must point to a PrimType"; DataType value_storage_type = prim->dtype; if (value_storage_type == DataType::Bool()) { // We need a physically addressable buffer type to support boolean tensors. @@ -212,7 +214,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re for (size_t i = 0; i < pod_args.size(); ++i) { Var v = pod_args[i]; - ICHECK(!v.dtype().is_handle()); + TVM_FFI_ICHECK(!v.dtype().is_handle()); std::string vid = AllocVarID(v.get()); if (v.dtype() == DataType::Int(32)) { @@ -222,7 +224,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re } else if (v.dtype() == DataType::Float(32)) { this->decl_stream << " " << vid << ": f32"; } else { - LOG(FATAL) << "Do not support pod argument type " << v.dtype(); + TVM_FFI_THROW(InternalError) << "Do not support pod argument type " << v.dtype(); } this->decl_stream << ",\n"; // value ref @@ -244,7 +246,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re os_param_access << "]"; func_launch_param_tags.push_back(os_param_access.str()); - ICHECK(!info.has_block_index_z) + TVM_FFI_ICHECK(!info.has_block_index_z) << "blockIdx.z is not supported in WebGPU to accomodate large blockIdx.x"; // anotate workgroup this->stream << "@compute @workgroup_size(" << info.workgroup_size[0] << ", " @@ -271,7 +273,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re } void CodeGenWebGPU::BindThreadIndex(const IterVar& iv) { - ICHECK(!var_idmap_.count(iv->var.get())); + TVM_FFI_ICHECK(!var_idmap_.count(iv->var.get())); std::ostringstream os; PrintType(iv->var.dtype(), os); if (iv->thread_tag == "blockIdx.x") { @@ -292,7 +294,7 @@ void CodeGenWebGPU::BindThreadIndex(const IterVar& iv) { void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - LOG(FATAL) << "Cannot print handle type in WebGPU"; + TVM_FFI_THROW(InternalError) << "Cannot print handle type in WebGPU"; } if (t.is_void()) { os << "void"; @@ -304,7 +306,8 @@ void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } if (lanes != 1) { - ICHECK(lanes >= 2 && lanes <= 4) << "CodeGenWebGPU: only allows vector with lanes in {2, 3, 4}"; + TVM_FFI_ICHECK(lanes >= 2 && lanes <= 4) + << "CodeGenWebGPU: only allows vector with lanes in {2, 3, 4}"; // Currently WebGPU doesn't support `i8` and an `int8x4` is represented as a `u32`. if (t.is_int() && t.bits() == 8 && lanes == 4) { os << "u32"; @@ -314,20 +317,20 @@ void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } if (t.is_float()) { - ICHECK(t.bits() == 16 || t.bits() == 32) << "CodeGenWebGPU: only support f16 or f32"; + TVM_FFI_ICHECK(t.bits() == 16 || t.bits() == 32) << "CodeGenWebGPU: only support f16 or f32"; if (t.bits() == 16) { // Using f16 requires enable directive enable_fp16_ = true; } os << "f" << t.bits(); } else if (t.is_uint()) { - ICHECK(t.bits() != 64) << "CodeGenWebGPU: do not support u64"; + TVM_FFI_ICHECK(t.bits() != 64) << "CodeGenWebGPU: do not support u64"; os << "u" << t.bits(); } else if (t.is_int()) { - ICHECK(t.bits() != 64) << "CodeGenWebGPU: do not support i64"; + TVM_FFI_ICHECK(t.bits() != 64) << "CodeGenWebGPU: do not support i64"; os << "i" << t.bits(); } else { - LOG(FATAL) << "CodeGenWebGPU: Cannot convert type " << t << " to WebGPU type"; + TVM_FFI_THROW(InternalError) << "CodeGenWebGPU: Cannot convert type " << t << " to WebGPU type"; } if (lanes != 1) { os << ">"; @@ -343,7 +346,7 @@ void CodeGenWebGPU::PrintStorageSync(const CallNode* op) { this->PrintIndent(); this->stream << "workgroupBarrier();\n"; } else if (sync == "global") { - LOG(FATAL) << "global barrier not supported"; + TVM_FFI_THROW(InternalError) << "global barrier not supported"; } } @@ -455,7 +458,7 @@ void CodeGenWebGPU::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT // use ssa form. if (print_ssa_form_) { std::string value = PrintExpr(op->value); - ICHECK(!var_idmap_.count(op->var.get())); + TVM_FFI_ICHECK(!var_idmap_.count(op->var.get())); var_idmap_[op->var.get()] = value; } else { PrintIndent(); @@ -469,7 +472,7 @@ void CodeGenWebGPU::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT // We do this because it is hard to completely avoid a same LetNode appearing // at different places. bool removed = var_idmap_.erase(op->var.get()); - ICHECK(removed); + TVM_FFI_ICHECK(removed); } void CodeGenWebGPU::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) @@ -478,7 +481,7 @@ void CodeGenWebGPU::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOL if (op->dtype.is_int()) { temp << op->value << "i"; } else { - ICHECK(op->dtype.is_uint()); + TVM_FFI_ICHECK(op->dtype.is_uint()); temp << op->value << "u"; } this->MarkConst(temp.str()); @@ -499,7 +502,7 @@ void CodeGenWebGPU::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // N enable_fp16_ = true; temp << 'h'; } else { - LOG(FATAL) << "Unsupported floating point bits " << op->dtype.bits(); + TVM_FFI_THROW(InternalError) << "Unsupported floating point bits " << op->dtype.bits(); } MarkConst(temp.str()); os << temp.str(); @@ -510,8 +513,8 @@ void CodeGenWebGPU::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // // Each printing stmt must stand on their own after all preprocessing steps // to ensure correctness in the case of nested-expression // do not try to lift common printings from each case - ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; - ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported."; + TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; + TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported."; DataType value_dtype = op->dtype; PrimExpr index = op->indices[0]; @@ -528,9 +531,9 @@ void CodeGenWebGPU::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // this->PrintType(value_dtype, os); os << "("; } else { - ICHECK(value_dtype == element_dtype); + TVM_FFI_ICHECK(value_dtype == element_dtype); } - ICHECK_EQ(index.dtype().lanes(), 1); + TVM_FFI_ICHECK_EQ(index.dtype().lanes(), 1); os << buffer_vid << "[" << this->PrintExpr(index) << "]"; // Special handle bool loading if (value_dtype == DataType::Bool()) { @@ -538,8 +541,8 @@ void CodeGenWebGPU::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // } } else { // Vector load from scalar buffer - ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array"; - ICHECK(value_dtype.element_of() == element_dtype) + TVM_FFI_ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array"; + TVM_FFI_ICHECK(value_dtype.element_of() == element_dtype) << "WebGPU vector loading requires base type to match"; arith::PVar base; if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) { @@ -570,7 +573,7 @@ void CodeGenWebGPU::VisitStmt_(const LetStmtNode* op) { // use ssa form. if (print_ssa_form_) { std::string value = PrintExpr(op->value); - ICHECK(!var_idmap_.count(op->var.get())); + TVM_FFI_ICHECK(!var_idmap_.count(op->var.get())); var_idmap_[op->var.get()] = value; } else { PrintIndent(); @@ -583,8 +586,8 @@ void CodeGenWebGPU::VisitStmt_(const LetStmtNode* op) { } void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { - CHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; - ICHECK(!op->predicate.defined()) << "Predicated buffer store is not supported."; + TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; + TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer store is not supported."; DataType value_dtype = op->value.dtype(); DataType element_dtype = op->buffer->dtype; @@ -606,7 +609,7 @@ void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { PrintType(element_dtype, stream); stream << "("; } else { - ICHECK(value_dtype == element_dtype); + TVM_FFI_ICHECK(value_dtype == element_dtype); } stream << value_vid; // Special handle bool store @@ -616,8 +619,8 @@ void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { stream << ";\n"; } else { // Vector store into scalar buffer - ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array"; - ICHECK(value_dtype.element_of() == element_dtype) + TVM_FFI_ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array"; + TVM_FFI_ICHECK(value_dtype.element_of() == element_dtype) << "WebGPU vector stire requires base type to match"; std::string value_vid = PrintExpr(op->value); arith::PVar base; @@ -644,10 +647,10 @@ void CodeGenWebGPU::VisitStmt_(const BufferStoreNode* op) { } void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) { - ICHECK(!is_zero(op->condition)); + TVM_FFI_ICHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); size_t constant_size = op->ConstantAllocationSize(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; + TVM_FFI_ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kShared) { @@ -665,7 +668,8 @@ void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) { PrintType(op->dtype, this->stream); this->stream << ", " << constant_size << ">;\n"; } else { - LOG(FATAL) << "WebGPU: Do not support storage scope: " << storage_scope.to_string(); + TVM_FFI_THROW(InternalError) << "WebGPU: Do not support storage scope: " + << storage_scope.to_string(); } this->PrintStmt(op->body); } @@ -725,7 +729,8 @@ class WebGPUSourceModuleNode final : public ffi::ModuleObj { int GetPropertyMask() const final { return ffi::Module::kBinarySerializable; } ffi::Optional GetFunction(const ffi::String& name) final { - LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run through tvmjs"; + TVM_FFI_THROW(InternalError) + << "WebGPUSourceModule is not directly runnable, export and run through tvmjs"; return std::nullopt; } @@ -776,13 +781,14 @@ ffi::Module BuildWebGPU(IRModule mod, Target target) { for (auto kv : mod->functions) { CodeGenWebGPU cg(target); - ICHECK(kv.second->IsInstance()) << "CodeGenWebGPU: Can only take PrimFunc"; + TVM_FFI_ICHECK(kv.second->IsInstance()) + << "CodeGenWebGPU: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + TVM_FFI_ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.has_value()) + TVM_FFI_ICHECK(global_symbol.has_value()) << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol.value(); cg.Init(output_ssa); diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 56b575cc6c38..4f041ff96e67 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -128,7 +128,7 @@ struct CUDAWarpIntrinsic { } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) { return Op::Get("tir.cuda.__shfl_up_sync"); } else { - ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down())); + TVM_FFI_ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down())); return Op::Get("tir.cuda.__shfl_down_sync"); } } @@ -142,8 +142,8 @@ static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr& e) { template static PrimExpr DispatchCUDAShuffle(const PrimExpr& e) { const CallNode* call = e.as(); - ICHECK(call != nullptr); - ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size + TVM_FFI_ICHECK(call != nullptr); + TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size ffi::Array cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}}; return Call(call->dtype, T()(call->dtype, Downcast(call->op)), cuda_args); } diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc index e74c63a79ba3..be888d47fb98 100644 --- a/src/target/source/intrin_rule_metal.cc +++ b/src/target/source/intrin_rule_metal.cc @@ -37,7 +37,7 @@ struct MetalWarpIntrinsic { } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) { return Op::Get("tir.metal.simd_shuffle_up"); } else { - ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down())); + TVM_FFI_ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down())); return Op::Get("tir.metal.simd_shuffle_down"); } } @@ -46,8 +46,8 @@ struct MetalWarpIntrinsic { template static PrimExpr DispatchMetalShuffle(const PrimExpr& e) { const CallNode* call = e.as(); - ICHECK(call != nullptr); - ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size + TVM_FFI_ICHECK(call != nullptr); + TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size ffi::Array metal_args{{call->args[1], call->args[2]}}; return Call(call->dtype, T()(call->dtype, Downcast(call->op)), metal_args); } diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index ea3a1c58bc3f..01c1e038cd1f 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -104,10 +104,10 @@ TVM_REGISTER_OP("tir.cosh") // When shuffle is used, we assume it is intel's shuffle extension static PrimExpr DispatchIntelShuffle(const PrimExpr& e) { const CallNode* call = e.as(); - ICHECK(call != nullptr); - ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size + TVM_FFI_ICHECK(call != nullptr); + TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size arith::Analyzer analyzer; - ICHECK(analyzer.CanProve(call->args[3] == call->args[4])) + TVM_FFI_ICHECK(analyzer.CanProve(call->args[3] == call->args[4])) << "Intel warp shuffle dose not support width != warp_size"; ffi::Array opencl_args{ {StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}}; diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc index c9c15ee0cb2e..70bc8557bf4e 100644 --- a/src/target/source/ptx.cc +++ b/src/target/source/ptx.cc @@ -127,7 +127,7 @@ inline DataType DTypeFromString(const std::string str) { } else if (str == ".b64") { return DataType::kBit64; } else { - LOG(FATAL) << "Unrecognized PTX data type " << str; + TVM_FFI_THROW(InternalError) << "Unrecognized PTX data type " << str; } } @@ -146,7 +146,7 @@ inline uint32_t DTypeBits(DataType dtype) { return num_bits[static_cast(dty */ inline std::tuple ParseMMAShape(const std::string& str) { size_t pos_m = str.find("m"), pos_n = str.find("n"), pos_k = str.find("k"); - CHECK(pos_m != str.npos && pos_n != str.npos && pos_k != str.npos) + TVM_FFI_ICHECK(pos_m != str.npos && pos_n != str.npos && pos_k != str.npos) << "Cannot parse MMA shape " << str; int m = std::stoi(str.substr(pos_m + 1, pos_n - pos_m - 1)), n = std::stoi(str.substr(pos_n + 1, pos_k - pos_n - 1)), k = std::stoi(str.substr(pos_k + 1)); @@ -167,7 +167,7 @@ LayoutType LayoutTypeFromString(const std::string& str) { } else if (str == "col") { return LayoutType::kColumnMajor; } else { - LOG(FATAL) << "Unrecognized layout type " << str; + TVM_FFI_THROW(InternalError) << "Unrecognized layout type " << str; } } @@ -264,24 +264,26 @@ void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_ case DataType::kBFloat16: case DataType::kTensorFloat32: case DataType::kFloat64: - CHECK(dtype_a == dtype_b) << ab_not_match_err_str; + TVM_FFI_ICHECK(dtype_a == dtype_b) << ab_not_match_err_str; break; case DataType::kInt4: case DataType::kUInt4: - CHECK(dtype_b == DataType::kInt4 || dtype_b == DataType::kUInt4) << ab_not_match_err_str; + TVM_FFI_ICHECK(dtype_b == DataType::kInt4 || dtype_b == DataType::kUInt4) + << ab_not_match_err_str; break; case DataType::kInt8: case DataType::kUInt8: - CHECK(dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8) << ab_not_match_err_str; + TVM_FFI_ICHECK(dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8) + << ab_not_match_err_str; break; case DataType::kFloat8_e4m3: case DataType::kFloat8_e5m2: - CHECK(dtype_b == DataType::kFloat8_e4m3 || dtype_b == DataType::kFloat8_e5m2) + TVM_FFI_ICHECK(dtype_b == DataType::kFloat8_e4m3 || dtype_b == DataType::kFloat8_e5m2) << ab_not_match_err_str; break; default: - CHECK(false) << "Invalid multiplicand data types: " << DTypeToString(dtype_a) - << DTypeToString(dtype_b); + TVM_FFI_ICHECK(false) << "Invalid multiplicand data types: " << DTypeToString(dtype_a) + << DTypeToString(dtype_b); } // check a,b and c switch (dtype_a) { @@ -290,31 +292,32 @@ void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_ case DataType::kUInt4: case DataType::kInt8: case DataType::kUInt8: - CHECK(dtype_c == DataType::kInt32) + TVM_FFI_ICHECK(dtype_c == DataType::kInt32) << "For multiplicand data type " << DTypeToString(dtype_a) << DTypeToString(dtype_b) << ", accumulator data type should be s32."; break; case DataType::kFloat16: - CHECK(dtype_c == DataType::kFloat16 || dtype_c == DataType::kFloat32) + TVM_FFI_ICHECK(dtype_c == DataType::kFloat16 || dtype_c == DataType::kFloat32) << "For multiplicand data type f16, accumulator data type should be f16/f32."; break; case DataType::kBFloat16: case DataType::kTensorFloat32: - CHECK(dtype_c == DataType::kFloat32) + TVM_FFI_ICHECK(dtype_c == DataType::kFloat32) << "For multiplicand data type bf16/tf32, accumulator data type can only be f32."; break; case DataType::kFloat64: - CHECK(dtype_c == DataType::kFloat64) + TVM_FFI_ICHECK(dtype_c == DataType::kFloat64) << "For multiplicand data type f64, accumulator data type can only be f64."; break; case DataType::kFloat8_e4m3: case DataType::kFloat8_e5m2: - CHECK(dtype_c == DataType::kFloat32) + TVM_FFI_ICHECK(dtype_c == DataType::kFloat32) << "For multiplicand data type e4m3/e5m2, accumulator data type can only be f32."; break; default: - CHECK(false) << "Invalid multiplicand/accumulator data types: " << DTypeToString(dtype_a) - << DTypeToString(dtype_b) << DTypeToString(dtype_c) << "."; + TVM_FFI_ICHECK(false) << "Invalid multiplicand/accumulator data types: " + << DTypeToString(dtype_a) << DTypeToString(dtype_b) + << DTypeToString(dtype_c) << "."; } } @@ -336,22 +339,23 @@ void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_ void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType layout_b, DataType dtype_a, DataType dtype_b, DataType dtype_c, const std::string& bit_op, bool sparse, bool saturate) { - CHECK(bit_op == "xor" || bit_op == "and" || bit_op == "") + TVM_FFI_ICHECK(bit_op == "xor" || bit_op == "and" || bit_op == "") << "Unrecognized 1-bit operation " << bit_op << " , can only be xor/and."; bool use_bit_op = !bit_op.empty(); if (use_bit_op) { - CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with 1-bit multiplicand."; + TVM_FFI_ICHECK(dtype_a == DataType::kBit1) + << "Bit operator is only compatible with 1-bit multiplicand."; } CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c); if (saturate) { - CHECK(dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4 || dtype_a == DataType::kInt8 || - dtype_a == DataType::kUInt8) + TVM_FFI_ICHECK(dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4 || + dtype_a == DataType::kInt8 || dtype_a == DataType::kUInt8) << "Output saturation only applicable to multiplicand type s4/u4/s8/u8."; } if (!(m == 8 && n == 8 && k == 4 && dtype_a == ptx::DataType::kFloat16)) { // Only MMA on m8n8k4 for fp16 supports customized layouts. - CHECK(layout_a == LayoutType::kRowMajor && layout_b == LayoutType::kColumnMajor) + TVM_FFI_ICHECK(layout_a == LayoutType::kRowMajor && layout_b == LayoutType::kColumnMajor) << "Invalid layout combination " << LayoutTypeToString(layout_a) << "," << LayoutTypeToString(layout_b) << "."; } @@ -364,7 +368,7 @@ void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType break; } } - CHECK(match) << "Cannot find matched MMA configurations."; + TVM_FFI_ICHECK(match) << "Cannot find matched MMA configurations."; } /*! @@ -406,7 +410,7 @@ inline FragAttrs GetFragAttrs(DataType dtype) { case DataType::kFloat64: return FragAttrs('d', 64, "(double *)"); default: - ICHECK(false) << DTypeToString(dtype) << " is not matrix data type in MMA."; + TVM_FFI_ICHECK(false) << DTypeToString(dtype) << " is not matrix data type in MMA."; return FragAttrs('\0', 0, ""); } } @@ -621,9 +625,11 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type const std::string& local_elem_offset, const std::string& smem_ptr, const std::string& smem_elem_offset) { - CHECK(num == 1 || num == 2 || num == 4) << "ldmatrix only accept loading 1/2/4 matrices."; + TVM_FFI_ICHECK(num == 1 || num == 2 || num == 4) + << "ldmatrix only accept loading 1/2/4 matrices."; ptx::DataType data_type = ptx::DTypeFromString(type); - CHECK(data_type == ptx::DataType::kBit16) << "ldmatrix only accept matrix with type .b16."; + TVM_FFI_ICHECK(data_type == ptx::DataType::kBit16) + << "ldmatrix only accept matrix with type .b16."; std::string asm_code = R"( { unsigned int addr = cast_smem_ptr_to_int({smem_addr}); @@ -682,8 +688,8 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, const std::string& global_elem_offset, const std::string& bytes, const std::string& predicate_value) { - CHECK(bytes == "16" || bytes == "12" || bytes == "8" || bytes == "4" || bytes == "2" || - bytes == "1") + TVM_FFI_ICHECK(bytes == "16" || bytes == "12" || bytes == "8" || bytes == "4" || bytes == "2" || + bytes == "1") << "Only support 16, 12, 8, 4, 2, 1 bytes for predicated cp.async"; std::string predicated_asm_code = R"( { diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 37b539d0139c..8edba9acc593 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -57,8 +57,8 @@ class SourceModuleNode : public ffi::ModuleObj { const char* kind() const final { return "source"; } ffi::Optional GetFunction(const ffi::String& name) final { - LOG(FATAL) << "Source module cannot execute, to get executable module" - << " build TVM with \'" << fmt_ << "\' runtime support"; + TVM_FFI_THROW(InternalError) << "Source module cannot execute, to get executable module" + << " build TVM with \'" << fmt_ << "\' runtime support"; } ffi::String InspectSource(const ffi::String& format) const final { return code_; } @@ -130,12 +130,12 @@ class CSourceModuleNode : public ffi::ModuleObj { support::BytesInStream stream(bytes); std::string code, fmt; - ICHECK(stream.Read(&code)) << "Loading code failed"; - ICHECK(stream.Read(&fmt)) << "Loading format failed"; + TVM_FFI_ICHECK(stream.Read(&code)) << "Loading code failed"; + TVM_FFI_ICHECK(stream.Read(&fmt)) << "Loading format failed"; std::vector tmp_func_names, tmp_const_vars; - CHECK(stream.Read(&tmp_func_names)) << "Loading func names failed"; - CHECK(stream.Read(&tmp_const_vars)) << "Loading const vars failed"; + TVM_FFI_ICHECK(stream.Read(&tmp_func_names)) << "Loading func names failed"; + TVM_FFI_ICHECK(stream.Read(&tmp_const_vars)) << "Loading const vars failed"; ffi::Array func_names; for (auto func_name : tmp_func_names) func_names.push_back(ffi::String(func_name)); @@ -151,10 +151,10 @@ class CSourceModuleNode : public ffi::ModuleObj { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); if (fmt == "c" || fmt == "cc" || fmt == "cpp" || fmt == "cu") { - ICHECK_NE(code_.length(), 0); + TVM_FFI_ICHECK_NE(code_.length(), 0); SaveBinaryToFile(file_name, code_); } else { - ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; + TVM_FFI_ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; } } @@ -211,8 +211,8 @@ class DeviceSourceModuleNode final : public ffi::ModuleObj { : data_(data), fmt_(fmt), fmap_(fmap), type_key_(type_key), fget_source_(fget_source) {} ffi::Optional GetFunction(const ffi::String& name) final { - LOG(FATAL) << "Source module cannot execute, to get executable module" - << " build TVM with \'" << fmt_ << "\' runtime support"; + TVM_FFI_THROW(InternalError) << "Source module cannot execute, to get executable module" + << " build TVM with \'" << fmt_ << "\' runtime support"; } ffi::String InspectSource(const ffi::String& format) const final { @@ -229,7 +229,7 @@ class DeviceSourceModuleNode final : public ffi::ModuleObj { void WriteToFile(const ffi::String& file_name, const ffi::String& format) const final { std::string fmt = GetFileFormat(file_name, format); - ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; + TVM_FFI_ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, data_); diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 1c0de5d06484..ff8053a3096b 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -40,7 +40,8 @@ CodeGenSPIRV::CodeGenSPIRV(Target target) : spirv_support_(target) {} runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) { this->InitFuncState(); - ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; + TVM_FFI_ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) + << "SPIRV only takes restricted memory model"; std::vector pod_args; uint32_t i_buffer = 0; @@ -53,13 +54,15 @@ runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::s DataType t = arg.dtype(); if (t.is_handle()) { auto* ptr = arg->type_annotation.as(); - ICHECK(ptr) << "All handles passed to the Vulkan codegen must have a type_annotation as a " - "PointerType, " - << "and must point to a PrimType"; + TVM_FFI_ICHECK(ptr) + << "All handles passed to the Vulkan codegen must have a type_annotation as a " + "PointerType, " + << "and must point to a PrimType"; auto* prim = ptr->element_type.as(); - ICHECK(prim) << "All handles passed to the Vulkan codegen must have a type_annotation as a " - "PointerType, " - << "and must point to a PrimType"; + TVM_FFI_ICHECK(prim) + << "All handles passed to the Vulkan codegen must have a type_annotation as a " + "PointerType, " + << "and must point to a PrimType"; DataType value_storage_type = prim->dtype; if (value_storage_type == DataType::Bool()) { // We need a physically addressable buffer type to support boolean tensors. @@ -109,7 +112,7 @@ runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::s builder_->CommitKernelFunction(func_ptr, name); - ICHECK_LE(shared_memory_bytes_used_, spirv_support_.max_shared_memory_per_block) + TVM_FFI_ICHECK_LE(shared_memory_bytes_used_, spirv_support_.max_shared_memory_per_block) << "Vulkan shader " << name << " uses " << shared_memory_bytes_used_ << " bytes of shared memory, " << "but target supports only " << spirv_support_.max_shared_memory_per_block << " bytes. " @@ -138,10 +141,10 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& ext if (ts.rank == 1) { v = builder_->GetLocalID(ts.dim_index); auto* sizeptr = extent.as(); - ICHECK(sizeptr) << "SPIRV only allows constant thread group size " - << " get " << extent; - ICHECK_GE(ts.dim_index, 0) << "vthread should have been optimized out by here"; - ICHECK_LT(ts.dim_index, 3); + TVM_FFI_ICHECK(sizeptr) << "SPIRV only allows constant thread group size " + << " get " << extent; + TVM_FFI_ICHECK_GE(ts.dim_index, 0) << "vthread should have been optimized out by here"; + TVM_FFI_ICHECK_LT(ts.dim_index, 3); workgroup_size_[ts.dim_index] = static_cast(sizeptr->value); } else { v = builder_->GetWorkgroupID(ts.dim_index); @@ -172,7 +175,7 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) { sync_scope = spv::ScopeWorkgroup; memory_semantics |= spv::MemorySemanticsWorkgroupMemoryMask; } else { - LOG(FATAL) << "Do not support sync " << sync; + TVM_FFI_THROW(InternalError) << "Do not support sync " << sync; } auto type_int = builder_->GetSType(DataType::Int(32)); @@ -185,7 +188,7 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const VarNode* op) { auto it = var_map_.find(op); - ICHECK(it != var_map_.end()) << "cannot find variable " << op->name_hint; + TVM_FFI_ICHECK(it != var_map_.end()) << "cannot find variable " << op->name_hint; return it->second; } @@ -198,7 +201,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const StringImmNode* op) { - LOG(FATAL) << "StringImm is not supported in Device code"; + TVM_FFI_THROW(InternalError) << "StringImm is not supported in Device code"; } spirv::Value CodeGenSPIRV::VisitExpr_(const CastNode* op) { @@ -286,7 +289,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) { auto it = let_binding_.find(op->var); if (it != let_binding_.end()) { - ICHECK(deep_equal_(it->second->value, op->value)) + TVM_FFI_ICHECK(deep_equal_(it->second->value, op->value)) << "Let cannot bind the same var to two different values"; } else { let_binding_[op->var] = op; @@ -298,7 +301,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { if (op->op.same_as(builtin::call_spirv_pure_glsl450())) { - ICHECK_GE(op->args.size(), 2U); + TVM_FFI_ICHECK_GE(op->args.size(), 2U); uint32_t inst_id = static_cast(op->args[0].as()->value); std::vector values; for (size_t i = 1; i < op->args.size(); ++i) { @@ -306,31 +309,31 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { } return builder_->CallGLSL450(builder_->GetSType(op->dtype), inst_id, values); } else if (op->op.same_as(builtin::bitwise_and())) { - ICHECK_EQ(op->args.size(), 2U); + TVM_FFI_ICHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); return builder_->MakeValue(spv::OpBitwiseAnd, a.stype, a, b); } else if (op->op.same_as(builtin::bitwise_xor())) { - ICHECK_EQ(op->args.size(), 2U); + TVM_FFI_ICHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); return builder_->MakeValue(spv::OpBitwiseXor, a.stype, a, b); } else if (op->op.same_as(builtin::bitwise_or())) { - ICHECK_EQ(op->args.size(), 2U); + TVM_FFI_ICHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); return builder_->MakeValue(spv::OpBitwiseOr, a.stype, a, b); } else if (op->op.same_as(builtin::bitwise_not())) { - ICHECK_EQ(op->args.size(), 1U); + TVM_FFI_ICHECK_EQ(op->args.size(), 1U); spirv::Value a = MakeValue(op->args[0]); return builder_->MakeValue(spv::OpNot, a.stype, a); } else if (op->op.same_as(builtin::shift_left())) { - ICHECK_EQ(op->args.size(), 2U); + TVM_FFI_ICHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); return builder_->MakeValue(spv::OpShiftLeftLogical, a.stype, a, b); } else if (op->op.same_as(builtin::shift_right())) { - ICHECK_EQ(op->args.size(), 2U); + TVM_FFI_ICHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); if (op->args[0].dtype().is_int()) { @@ -342,7 +345,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype), MakeValue(op->args[0])); } else if (op->op.same_as(builtin::large_uint_imm())) { - ICHECK_EQ(op->args.size(), 2U); + TVM_FFI_ICHECK_EQ(op->args.size(), 2U); uint64_t low = static_cast(Downcast(op->args[0])->value); uint64_t high = static_cast(Downcast(op->args[1])->value); uint64_t val = (high << 32U) | low; @@ -350,7 +353,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { } else if (op->op.same_as(builtin::tvm_storage_sync())) { return this->CreateStorageSync(op); } else if (op->op.same_as(builtin::if_then_else())) { - ICHECK_EQ(op->args.size(), 3U); + TVM_FFI_ICHECK_EQ(op->args.size(), 3U); spirv::Value cond = MakeValue(op->args[0]); spirv::Label then_label = builder_->NewLabel(); spirv::Label else_label = builder_->NewLabel(); @@ -377,7 +380,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(op->dtype), MakeValue(op->args[0])); } else if (op->op.same_as(builtin::call_pure_extern())) { - ICHECK_GE(op->args.size(), 1U); + TVM_FFI_ICHECK_GE(op->args.size(), 1U); const std::string& func_name = op->args[0].as()->value; if (func_name == "__dp4a") { std::vector values; @@ -386,21 +389,23 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { } return builder_->CallKHRIntegerDotProduct(builder_->GetSType(op->dtype), values, op->dtype); } else { - LOG(FATAL) << "SPIR-V shader cannot make extern calls. Graph contains extern \"" - << Downcast(op->args[0]) << "\""; + TVM_FFI_THROW(InternalError) + << "SPIR-V shader cannot make extern calls. Graph contains extern \"" + << Downcast(op->args[0]) << "\""; return spirv::Value(); } } else if (op->op.same_as(builtin::call_extern())) { - ICHECK_GE(op->args.size(), 1U); - LOG(FATAL) << "SPIR-V shader cannot make extern calls. Graph contains extern \"" - << Downcast(op->args[0]) << "\""; + TVM_FFI_ICHECK_GE(op->args.size(), 1U); + TVM_FFI_THROW(InternalError) + << "SPIR-V shader cannot make extern calls. Graph contains extern \"" + << Downcast(op->args[0]) << "\""; return spirv::Value(); } else if (op->op.same_as(builtin::tvm_fill_fragment())) { - ICHECK_EQ(op->args.size(), 6U); + TVM_FFI_ICHECK_EQ(op->args.size(), 6U); const VarNode* buffer_node = op->args[0].as(); - ICHECK(buffer_node && fragment_info_.count(buffer_node)); + TVM_FFI_ICHECK(buffer_node && fragment_info_.count(buffer_node)); DataType ele_dtype = GetElementDataType(buffer_node); - ICHECK(ele_dtype.is_float()) << "Only floating point fragment accumulator is supported"; + TVM_FFI_ICHECK(ele_dtype.is_float()) << "Only floating point fragment accumulator is supported"; spirv::SType ele_stype = builder_->GetSType(ele_dtype); spirv::SType& fragment_type = fragment_info_[buffer_node].stype; double init = static_cast(Downcast(op->args[5])->value); @@ -409,15 +414,15 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { spirv::SType ptr_type = builder_->GetPointerType(fragment_type, fragment_info_[buffer_node].sclass); spirv::Value index = MakeValue(prim_index); - ICHECK(var_map_.count(buffer_node)); + TVM_FFI_ICHECK(var_map_.count(buffer_node)); spirv::Value ptr = builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], index); builder_->MakeInst(spv::OpStore, ptr, init_val, spv::MemoryAccessMaskNone); return spirv::Value(); } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) { - ICHECK_EQ(op->args.size(), 8U); + TVM_FFI_ICHECK_EQ(op->args.size(), 8U); const VarNode* buffer_node = op->args[0].as(); - ICHECK(buffer_node && fragment_info_.count(buffer_node)); + TVM_FFI_ICHECK(buffer_node && fragment_info_.count(buffer_node)); spirv::SType& fragment_type = fragment_info_[buffer_node].stype; PrimExpr dst_index = op->args[4]; PrimExpr src_ptr_expr = op->args[5]; @@ -476,7 +481,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { builder_->MakeInst(spv::OpStore, ptr_d, result, spv::MemoryAccessMaskNone); return spirv::Value(); } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) { - ICHECK_EQ(op->args.size(), 8U); + TVM_FFI_ICHECK_EQ(op->args.size(), 8U); const VarNode* buffer_node = op->args[0].as(); PrimExpr index = op->args[4]; PrimExpr buffer_ptr = op->args[5]; @@ -507,12 +512,12 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { spirv::SType ele_stype = builder_->GetSType(ele_dtype); spirv::Value buffer_val = MakeValue(buffer_var); spirv::SType ptr_type = builder_->GetPointerType(ele_stype, buffer_val.stype.storage_class); - ICHECK(var_map_.count(buffer_node)); + TVM_FFI_ICHECK(var_map_.count(buffer_node)); return builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index)); } else if (op->op.same_as(builtin::tvm_thread_invariant())) { return MakeValue(op->args[0]); } else { - LOG(FATAL) << "Unresolved call " << op->op; + TVM_FFI_THROW(InternalError) << "Unresolved call " << op->op; } } @@ -542,8 +547,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) { - ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers"; - ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported."; + TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers"; + TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported."; Var buffer_var = op->buffer->data; PrimExpr prim_index = op->indices[0]; @@ -553,7 +558,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) { } auto it = storage_info_.find(buffer_var.get()); - ICHECK(it != storage_info_.end()); + TVM_FFI_ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; info.CheckContentType(desired_read_type, prim_index.dtype().lanes()); @@ -593,9 +598,10 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) { return builder_->Concat(values); } else { - LOG(FATAL) << "Cannot perform buffer access of buffer variable '" << buffer_var->name_hint - << "' with element type " << info.element_type << " using index of type " - << prim_index->dtype << " to produce output of type " << op->dtype; + TVM_FFI_THROW(InternalError) << "Cannot perform buffer access of buffer variable '" + << buffer_var->name_hint << "' with element type " + << info.element_type << " using index of type " + << prim_index->dtype << " to produce output of type " << op->dtype; return spirv::Value(); } } @@ -616,7 +622,7 @@ void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::functionvectors.size() == 1 && op->indices.size() == 1) + TVM_FFI_ICHECK(op->vectors.size() == 1 && op->indices.size() == 1) << "SPIR-V codegen only supports shuffle " << "of one vector with one index"; spirv::Value vector = MakeValue(op->vectors[0]); @@ -627,13 +633,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const ShuffleNode* op) { } void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { - ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers"; - ICHECK(!op->predicate.defined()) << "Predicated buffer store is not supported."; + TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers"; + TVM_FFI_ICHECK(!op->predicate.defined()) << "Predicated buffer store is not supported."; Var buffer_var = op->buffer->data; PrimExpr prim_index = op->indices[0]; auto it = storage_info_.find(buffer_var.get()); - ICHECK(it != storage_info_.end()); + TVM_FFI_ICHECK(it != storage_info_.end()); StorageInfo& info = it->second; info.CheckContentType(op->value.dtype(), prim_index.dtype().lanes()); @@ -650,7 +656,7 @@ void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { if (op->value.dtype() == info.element_type) { // Requested store of a single value. This may be a scalar store // or a vectorized store, based on the array element type. - ICHECK_EQ(info.element_type, op->value.dtype()) + TVM_FFI_ICHECK_EQ(info.element_type, op->value.dtype()) << "Vulkan only allow one type access to the same buffer"; spirv::Value index = MakeValue(prim_index); spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); @@ -667,9 +673,10 @@ void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) { this->Scalarize(prim_index, f); } else { - LOG(FATAL) << "Cannot store value of type " << op->value.dtype() << " into buffer variable '" - << buffer_var->name_hint << "' with element type " << info.element_type - << " using index of type " << prim_index->dtype; + TVM_FFI_THROW(InternalError) << "Cannot store value of type " << op->value.dtype() + << " into buffer variable '" << buffer_var->name_hint + << "' with element type " << info.element_type + << " using index of type " << prim_index->dtype; } } @@ -792,10 +799,10 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) { } void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { - ICHECK(!is_zero(op->condition)); - ICHECK(!op->dtype.is_handle()); + TVM_FFI_ICHECK(!is_zero(op->condition)); + TVM_FFI_ICHECK(!op->dtype.is_handle()); size_t constant_size = op->ConstantAllocationSize(); - ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; + TVM_FFI_ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; spirv::Value buf; const std::string scope = GetPtrStorageScope(op->buffer_var); @@ -809,12 +816,12 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { case runtime::StorageRank::kWMMAMatrixA: case runtime::StorageRank::kWMMAMatrixB: case runtime::StorageRank::kWMMAAccumulator: { - ICHECK(fragment_info_.count(var_node)); + TVM_FFI_ICHECK(fragment_info_.count(var_node)); fragment_info_[var_node].scope = scope; etype = GetFragmentSType(var_node, op->dtype); storage_class = spv::StorageClassFunction; fragment_info_[var_node].sclass = storage_class; - ICHECK(fragment_info_.count(var_node)); + TVM_FFI_ICHECK(fragment_info_.count(var_node)); const std::string& scope = fragment_info_[var_node].scope; const std::string& shape_str = fragment_info_.at(var_node).shape; std::pair dim = GetWmmaFragmentDimSize(shape_str, scope); @@ -837,16 +844,16 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { shared_memory_bytes_used_ += num_bytes; } break; default: - LOG(FATAL) << "Can only allocate shared or local memory inside kernel"; + TVM_FFI_THROW(InternalError) << "Can only allocate shared or local memory inside kernel"; } builder_->SetName(buf, op->buffer_var->name_hint); StorageInfo& info = storage_info_[op->buffer_var.get()]; - ICHECK(!info.element_type_known); + TVM_FFI_ICHECK(!info.element_type_known); info.SetContentType(op->dtype, op->buffer_var->name_hint); - ICHECK(!var_map_.count(op->buffer_var.get())); + TVM_FFI_ICHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); } @@ -865,11 +872,11 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { } } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); - ICHECK(v); + TVM_FFI_ICHECK(v); storage_info_[v].is_volatile = true; } else if (op->attr_key == tir::attr::buffer_bind_scope) { const VarNode* v = op->node.as(); - ICHECK(v); + TVM_FFI_ICHECK(v); } else if (op->attr_key == tir::attr::fragment_shape) { const VarNode* buffer = op->node.as(); const StringImmNode* shape_str = op->value.as(); @@ -884,8 +891,8 @@ void CodeGenSPIRV::VisitStmt_(const AssertStmtNode* op) { } void CodeGenSPIRV::VisitStmt_(const LetStmtNode* op) { - ICHECK(!var_map_.count(op->var.get())); - ICHECK(!op->var.dtype().is_handle()); + TVM_FFI_ICHECK(!var_map_.count(op->var.get())); + TVM_FFI_ICHECK(!op->var.dtype().is_handle()); var_map_[op->var.get()] = MakeValue(op->value); analyzer_->Bind(op->var, op->value); this->VisitStmt(op->body); @@ -900,7 +907,7 @@ void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) { void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } spirv::SType CodeGenSPIRV::GetFragmentSType(const VarNode* buffer, const DataType& dtype) { - ICHECK(fragment_info_.count(buffer)); + TVM_FFI_ICHECK(fragment_info_.count(buffer)); const std::string& scope = fragment_info_[buffer].scope; const std::string& shape_str = fragment_info_.at(buffer).shape; std::pair dim = GetWmmaFragmentDimSize(shape_str, scope); @@ -912,7 +919,7 @@ spirv::SType CodeGenSPIRV::GetFragmentSType(const VarNode* buffer, const DataTyp DataType CodeGenSPIRV::GetElementDataType(const VarNode* buffer) { auto it = storage_info_.find(buffer); - ICHECK(it != storage_info_.end()); + TVM_FFI_ICHECK(it != storage_info_.end()); return it->second.element_type; } diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index e5fde107f452..fb4d4b3c9739 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -157,18 +157,19 @@ class CodeGenSPIRV : public ExprFunctor, * the number of lanes of the index. */ void CheckContentType(DataType type, int index_lanes = 1) const { - ICHECK(element_type_known) << "Cannot check element type of buffer " << name_hint - << " no previous element type defined"; + TVM_FFI_ICHECK(element_type_known) << "Cannot check element type of buffer " << name_hint + << " no previous element type defined"; DataType expected_type = element_type.with_lanes(index_lanes * element_type.lanes()); - ICHECK_EQ(type, expected_type) << "Attempted to access buffer " << name_hint - << " as element type " << type << " using an index of size " - << index_lanes << " when the element type is " << element_type; + TVM_FFI_ICHECK_EQ(type, expected_type) + << "Attempted to access buffer " << name_hint << " as element type " << type + << " using an index of size " << index_lanes << " when the element type is " + << element_type; } // Update content type if it hasn't been updated. void SetContentType(DataType type, std::string name_hint) { - ICHECK(!element_type_known) << "Cannot set element type of buffer " << name_hint - << " a second time."; + TVM_FFI_ICHECK(!element_type_known) + << "Cannot set element type of buffer " << name_hint << " a second time."; this->element_type = type; this->name_hint = name_hint; element_type_known = true; diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index a689a550c4aa..7415367df8b3 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -36,7 +36,7 @@ namespace spirv { template PrimExpr CallGLSLIntrin(PrimExpr e, const ffi::Array& args) { const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); ffi::Array cargs; // intrin id. cargs.push_back(IntImm(DataType::UInt(32), id)); @@ -50,7 +50,7 @@ PrimExpr CallGLSLIntrin(PrimExpr e, const ffi::Array& args) { template PrimExpr CallGLSLIntrin(PrimExpr e) { const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); + TVM_FFI_ICHECK(call != nullptr); return CallGLSLIntrin(e, call->args); } @@ -145,8 +145,8 @@ using tir::FLegalize; TVM_REGISTER_OP("tir.clz").set_attr( "vulkan.FLegalize", [](const PrimExpr& e) -> PrimExpr { const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); - ICHECK_EQ(call->args.size(), 1); + TVM_FFI_ICHECK(call != nullptr); + TVM_FFI_ICHECK_EQ(call->args.size(), 1); PrimExpr arg = call->args[0]; PrimExpr msb; if (arg.dtype().bits() == 64) { @@ -160,7 +160,7 @@ TVM_REGISTER_OP("tir.clz").set_attr( } else if (arg.dtype().bits() == 32) { msb = CallGLSLIntrin(e); } else { - LOG(FATAL) << "SPIR-V clz only supports a 32 bit or 64 bit integer."; + TVM_FFI_THROW(InternalError) << "SPIR-V clz only supports a 32 bit or 64 bit integer."; } return PrimExpr(arg.dtype().bits() - 1) - msb; }); diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index bac66a3aacf7..135888c23d8b 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -34,7 +34,7 @@ namespace spirv { IRBuilder::IRBuilder(const SPIRVSupport& support) : spirv_support_(support) {} void IRBuilder::InitHeader() { - ICHECK_EQ(header_.size(), 0U); + TVM_FFI_ICHECK_EQ(header_.size(), 0U); header_.push_back(spv::MagicNumber); // Target SPIR-V version 1.0. Additional functionality will be @@ -126,7 +126,7 @@ SType IRBuilder::GetSType(const DataType& dtype, uint32_t row, uint32_t col) { type_key = static_cast(dtype.code()); type_key |= static_cast(dtype.bits()) << 8U; if (row * col == 0) { - ICHECK((row == 0) && (col == 0)); + TVM_FFI_ICHECK((row == 0) && (col == 0)); type_key |= static_cast(dtype.lanes()) << 16U; } else { type_key |= static_cast(row) << 32U; @@ -143,7 +143,7 @@ SType IRBuilder::GetSType(const DataType& dtype, uint32_t row, uint32_t col) { } SType IRBuilder::GetPointerType(const SType& value_type, spv::StorageClass storage_class) { - ICHECK_NE(storage_class, spv::StorageClassMax); + TVM_FFI_ICHECK_NE(storage_class, spv::StorageClassMax); auto key = std::make_pair(value_type.id, storage_class); auto it = pointer_type_tbl_.find(key); if (it != pointer_type_tbl_.end()) { @@ -179,7 +179,7 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems, ib_.Begin(spv::OpTypeRuntimeArray).AddSeq(arr_type, value_type).Commit(&global_); } int nbits = value_type.type.bits() * value_type.type.lanes(); - ICHECK_EQ(nbits % 8, 0); + TVM_FFI_ICHECK_EQ(nbits % 8, 0); uint32_t nbytes = static_cast(nbits) / 8; // decorate the array type. this->Decorate(spv::OpDecorate, arr_type, spv::DecorationArrayStride, nbytes); @@ -214,7 +214,7 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems, } Value IRBuilder::StructArrayAccess(const SType& res_type, Value buffer, Value index) { - ICHECK(buffer.flag == kStructArrayPtr); + TVM_FFI_ICHECK(buffer.flag == kStructArrayPtr); return MakeValue(spv::OpInBoundsAccessChain, res_type, buffer, const_i32_zero_, index); } @@ -233,7 +233,7 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) { uint64_t data = ptr[0]; return GetConst_(dtype, &data); } else { - ICHECK_EQ(dtype.type.bits(), 16); + TVM_FFI_ICHECK_EQ(dtype.type.bits(), 16); float fvalue = static_cast(value); uint32_t* ptr = reinterpret_cast(&fvalue); uint64_t data = ptr[0]; @@ -283,13 +283,13 @@ Value IRBuilder::DeclareStorageVariable(const std::vector& value_types, .Commit(&decorate_); DataType t = value_types[i].type; uint32_t nbits = t.bits() * t.lanes(); - ICHECK_EQ(nbits % 8, 0); + TVM_FFI_ICHECK_EQ(nbits % 8, 0); uint32_t bytes = (nbits / 8); if (t.bits() == 32) { // In our Vulkan runtime, each scalar argument always occupies 64 bit. offset += bytes * 2; } else { - ICHECK_EQ(t.bits(), 64); + TVM_FFI_ICHECK_EQ(t.bits(), 64); offset += bytes; } } @@ -302,7 +302,7 @@ Value IRBuilder::DeclareStorageVariable(const std::vector& value_types, } Value IRBuilder::DeclarePushConstant(const std::vector& value_types) { - ICHECK_EQ(push_const_.id, 0); + TVM_FFI_ICHECK_EQ(push_const_.id, 0); return DeclareStorageVariable(value_types, spv::StorageClassPushConstant, kPushConstantPtr); } @@ -335,7 +335,7 @@ Value IRBuilder::GetUniform(Value ptr_push_const, const SType& v_type, uint32_t Value IRBuilder::NewFunction() { return NewValue(t_void_func_, kFunction); } void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) { - ICHECK_EQ(func.flag, kFunction); + TVM_FFI_ICHECK_EQ(func.flag, kFunction); ib_.Begin(spv::OpEntryPoint).AddSeq(spv::ExecutionModelGLCompute, func, name); for (auto& it : built_in_tbl_) { ib_.Add(it.second); @@ -344,7 +344,7 @@ void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) } void IRBuilder::StartFunction(const Value& func) { - ICHECK_EQ(func.flag, kFunction); + TVM_FFI_ICHECK_EQ(func.flag, kFunction); // add function declaration to the header. ib_.Begin(spv::OpFunction).AddSeq(t_void_, func, 0, t_void_func_).Commit(&func_header_); @@ -354,7 +354,7 @@ void IRBuilder::StartFunction(const Value& func) { } void IRBuilder::SetLocalSize(const Value& func, uint32_t local_size[3]) { - ICHECK_EQ(func.flag, kFunction); + TVM_FFI_ICHECK_EQ(func.flag, kFunction); ib_.Begin(spv::OpExecutionMode) .AddSeq(func, spv::ExecutionModeLocalSize, local_size[0], local_size[1], local_size[2]) .Commit(&exec_mode_); @@ -362,7 +362,7 @@ void IRBuilder::SetLocalSize(const Value& func, uint32_t local_size[3]) { Value IRBuilder::Allocate(const SType& value_type, uint32_t num_elems, spv::StorageClass storage_class) { - ICHECK_NE(num_elems, 0U); + TVM_FFI_ICHECK_NE(num_elems, 0U); SType sarr_type = GetStructArrayType(value_type, num_elems, false); SType ptr_type = GetPointerType(sarr_type, storage_class); Value val = NewValue(ptr_type, kStructArrayPtr); @@ -403,7 +403,7 @@ Value IRBuilder::GetBuiltInValue(spv::BuiltIn built_in, uint32_t index, const st break; default: - LOG(FATAL) << "No data type defined for SPIR-V Built-In " << built_in; + TVM_FFI_THROW(InternalError) << "No data type defined for SPIR-V Built-In " << built_in; } // Look up the decorated array value at global scope. If it doesn't @@ -465,7 +465,7 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { if (it != const_tbl_.end()) { return it->second; } - ICHECK_LE(dtype.type.bits(), 64); + TVM_FFI_ICHECK_LE(dtype.type.bits(), 64); Value ret = NewValue(dtype, kConstant); if (dtype.type == DataType::Bool()) { // bool types. @@ -510,7 +510,7 @@ SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) } else if (dtype.is_float()) { ib_.Begin(spv::OpTypeFloat).AddSeq(t, dtype.bits()).Commit(&global_); } else { - LOG(FATAL) << "declare type do not support handle"; + TVM_FFI_THROW(InternalError) << "declare type do not support handle"; } return t; } else { @@ -520,7 +520,7 @@ SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) SType base_type = GetSType(dtype.element_of()); if (row * col == 0) { - ICHECK((row == 0) && (col == 0)); + TVM_FFI_ICHECK((row == 0) && (col == 0)); ib_.Begin(spv::OpTypeVector).AddSeq(t, base_type, dtype.lanes()).Commit(&global_); } else { Value v_row = GetSpecConst(GetSType(DataType::UInt(32)), row); @@ -538,21 +538,21 @@ void IRBuilder::AddCapabilityFor(const DataType& dtype) { // Declare appropriate capabilities for int/float types if (dtype.is_int() || dtype.is_uint()) { if (dtype.bits() == 8) { - ICHECK(spirv_support_.supports_int8) + TVM_FFI_ICHECK(spirv_support_.supports_int8) << "Vulkan target does not support Int8 capability. " << "If your device supports 8-bit int operations, " << "please either add -supports_int8=1 to the target, " << "or query all device parameters by adding -from_device=0."; capabilities_used_.insert(spv::CapabilityInt8); } else if (dtype.bits() == 16) { - ICHECK(spirv_support_.supports_int16) + TVM_FFI_ICHECK(spirv_support_.supports_int16) << "Vulkan target does not support Int16 capability. " << "If your device supports 16-bit int operations, " << "please either add -supports_int16=1 to the target, " << "or query all device parameters by adding -from_device=0."; capabilities_used_.insert(spv::CapabilityInt16); } else if (dtype.bits() == 64) { - ICHECK(spirv_support_.supports_int64) + TVM_FFI_ICHECK(spirv_support_.supports_int64) << "Vulkan target does not support Int64 capability. " << "If your device supports 64-bit int operations, " << "please either add -supports_int64=1 to the target, " @@ -562,14 +562,14 @@ void IRBuilder::AddCapabilityFor(const DataType& dtype) { } else if (dtype.is_float()) { if (dtype.bits() == 16) { - ICHECK(spirv_support_.supports_float16) + TVM_FFI_ICHECK(spirv_support_.supports_float16) << "Vulkan target does not support Float16 capability. " << "If your device supports 16-bit float operations, " << "please either add -supports_float16=1 to the target, " << "or query all device parameters by adding -from_device=0."; capabilities_used_.insert(spv::CapabilityFloat16); } else if (dtype.bits() == 64) { - ICHECK(spirv_support_.supports_float64) + TVM_FFI_ICHECK(spirv_support_.supports_float64) << "Vulkan target does not support Float64 capability. " << "If your device supports 64-bit float operations, " << "please either add -supports_float64=1 to the target, " @@ -584,7 +584,7 @@ void IRBuilder::AddCapabilityFor(const DataType& dtype) { // Int8 prevents use of an 8-bit loop iterator on a device that // supports Int8 but doesn't support 8-bit buffer access. if (dtype.bits() == 8 && !dtype.is_bool()) { - ICHECK(spirv_support_.supports_storage_buffer_8bit_access) + TVM_FFI_ICHECK(spirv_support_.supports_storage_buffer_8bit_access) << "Vulkan target does not support StorageBuffer8BitAccess. " << "If your device supports 8-bit buffer access, " << "please either add -supports_8bit_buffer=1 to the target, " @@ -592,14 +592,14 @@ void IRBuilder::AddCapabilityFor(const DataType& dtype) { capabilities_used_.insert(spv::CapabilityStorageBuffer8BitAccess); extensions_used_.insert("SPV_KHR_8bit_storage"); - ICHECK(spirv_support_.supports_storage_buffer_storage_class) + TVM_FFI_ICHECK(spirv_support_.supports_storage_buffer_storage_class) << "Illegal Vulkan target description. " << "Vulkan spec requires extension VK_KHR_storage_buffer_storage_class " << "if VK_KHR_8bit_storage is supported. " << "Please either add -supports_storage_buffer_storage_class=1 to the target, " << "or query all device parameters by adding -from_device=0."; } else if (dtype.bits() == 16) { - ICHECK(spirv_support_.supports_storage_buffer_16bit_access) + TVM_FFI_ICHECK(spirv_support_.supports_storage_buffer_16bit_access) << "Vulkan target does not support StorageBuffer16BitAccess. " << "If your device supports 16-bit buffer access, " << "please either add -supports_16bit_buffer=1 to the target, " @@ -625,7 +625,7 @@ PhiValue IRBuilder::MakePhi(const SType& out_type, uint32_t num_incoming) { phi.stype = out_type; phi.flag = kNormal; phi.instr = ib_.Commit(&function_); - ICHECK_EQ(phi.instr.WordCount(), 2 * num_incoming + 3); + TVM_FFI_ICHECK_EQ(phi.instr.WordCount(), 2 * num_incoming + 3); return phi; } @@ -643,11 +643,11 @@ Value IRBuilder::CallGLSL450(const SType& ret_type, uint32_t inst_id, Value IRBuilder::CallKHRIntegerDotProduct(const SType& ret_type, const std::vector& args, const DataType& dtype) { if (args.size() != 3) { - LOG(FATAL) << "Unresolved arguments in SPIRV_KHR_integer_dot_product"; + TVM_FFI_THROW(InternalError) << "Unresolved arguments in SPIRV_KHR_integer_dot_product"; } Value val = NewValue(ret_type, kNormal); #ifdef TVM_SPIRV_KHR_INTEGER_DOT_PRODUCT - ICHECK(spirv_support_.supports_integer_dot_product) + TVM_FFI_ICHECK(spirv_support_.supports_integer_dot_product) << "Vulkan target does not support integer dot product capability. " << "If your device supports integer dot product operations, " << "please either add -mattr=+dotprod to the target, " @@ -657,10 +657,11 @@ Value IRBuilder::CallKHRIntegerDotProduct(const SType& ret_type, const std::vect } else if (dtype.is_uint()) { ib_.Begin(spv::OpUDotAccSatKHR).AddSeq(ret_type, val); } else { - LOG(FATAL) << "Unsupported type"; + TVM_FFI_THROW(InternalError) << "Unsupported type"; } #else - LOG(FATAL) << "Please turn on USE_SPIRV_KHR_INTEGER_DOT_PRODUCT in config.cmake"; + TVM_FFI_THROW(InternalError) + << "Please turn on USE_SPIRV_KHR_INTEGER_DOT_PRODUCT in config.cmake"; #endif for (const Value& v : args) { @@ -675,7 +676,7 @@ Value IRBuilder::Concat(const std::vector& vec) { DataType etype = vec[0].stype.type; int lanes = etype.lanes(); for (size_t i = 1; i < vec.size(); ++i) { - ICHECK_EQ(etype, vec[i].stype.type.element_of()) + TVM_FFI_ICHECK_EQ(etype, vec[i].stype.type.element_of()) << "Cannot concat vector of different element type"; lanes += vec[i].stype.type.lanes(); is_const = is_const && (vec[i].flag == kConstant); @@ -700,11 +701,11 @@ Value IRBuilder::Concat(const std::vector& vec) { } Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { - ICHECK_NE(value.stype.id, 0U); + TVM_FFI_ICHECK_NE(value.stype.id, 0U); if (value.stype.id == dst_type.id) return value; const tvm::DataType& from = value.stype.type; const tvm::DataType& to = dst_type.type; - ICHECK_EQ(from.lanes(), to.lanes()); + TVM_FFI_ICHECK_EQ(from.lanes(), to.lanes()); if (from == DataType::Bool()) { if (to.is_int()) { return Select(value, IntImm(dst_type, 1), IntImm(dst_type, 0)); @@ -714,7 +715,7 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { return MakeValue(spv::OpConvertUToF, dst_type, Select(value, UIntImm(t_uint32_, 1), UIntImm(t_uint32_, 0))); } else { - LOG(FATAL) << "cannot cast from " << from << " to " << to; + TVM_FFI_THROW(InternalError) << "cannot cast from " << from << " to " << to; return Value(); } } else if (to == DataType::Bool()) { @@ -723,7 +724,7 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { } else if (to.is_uint()) { return NE(value, UIntImm(value.stype, 0)); } else { - LOG(FATAL) << "cannot cast from " << from << " to " << to; + TVM_FFI_THROW(InternalError) << "cannot cast from " << from << " to " << to; return Value(); } } else if (from.is_int() && to.is_int()) { @@ -751,7 +752,7 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { } else if (from.is_float() && to.is_float()) { return MakeValue(spv::OpFConvert, dst_type, value); } else { - LOG(FATAL) << "do not support type cast from " << from << " to " << to; + TVM_FFI_THROW(InternalError) << "do not support type cast from " << from << " to " << to; return Value(); } } @@ -772,7 +773,7 @@ Value IRBuilder::GetCompositeConst(const SType& ele_stype, const SType& composit } Value IRBuilder::GetSpecConst(const SType& dtype, uint64_t value) { - ICHECK_LE(dtype.type.bits(), 32); + TVM_FFI_ICHECK_LE(dtype.type.bits(), 32); Value ret = NewValue(dtype, kSpecConst); ib_.Begin(spv::OpSpecConstant).AddSeq(dtype, ret); ib_.Add(static_cast(value)); @@ -782,24 +783,24 @@ Value IRBuilder::GetSpecConst(const SType& dtype, uint64_t value) { #define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ Value IRBuilder::_OpName(Value a, Value b) { \ - ICHECK_EQ(a.stype.id, b.stype.id); \ + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ return MakeValue(spv::OpI##_Op, a.stype, a, b); \ } else { \ - ICHECK(a.stype.type.is_float()); \ + TVM_FFI_ICHECK(a.stype.type.is_float()); \ return MakeValue(spv::OpF##_Op, a.stype, a, b); \ } \ } #define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \ Value IRBuilder::_OpName(Value a, Value b) { \ - ICHECK_EQ(a.stype.id, b.stype.id); \ + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ if (a.stype.type.is_int()) { \ return MakeValue(spv::OpS##_Op, a.stype, a, b); \ } else if (a.stype.type.is_uint()) { \ return MakeValue(spv::OpU##_Op, a.stype, a, b); \ } else { \ - ICHECK(a.stype.type.is_float()); \ + TVM_FFI_ICHECK(a.stype.type.is_float()); \ return MakeValue(spv::OpF##_Op, a.stype, a, b); \ } \ } @@ -810,28 +811,28 @@ DEFINE_BUILDER_BINARY_USIGN_OP(Mul, Mul); DEFINE_BUILDER_BINARY_SIGN_OP(Div, Div); Value IRBuilder::Mod(Value a, Value b) { - ICHECK_EQ(a.stype.id, b.stype.id); + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); if (a.stype.type.is_int()) { return MakeValue(spv::OpSRem, a.stype, a, b); } else if (a.stype.type.is_uint()) { return MakeValue(spv::OpUMod, a.stype, a, b); } else { - ICHECK(a.stype.type.is_float()); + TVM_FFI_ICHECK(a.stype.type.is_float()); return MakeValue(spv::OpFRem, a.stype, a, b); } } #define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ Value IRBuilder::_OpName(Value a, Value b) { \ - ICHECK_EQ(a.stype.id, b.stype.id); \ - ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ + TVM_FFI_ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ const auto& bool_type = this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \ if (a.stype.type.is_int()) { \ return MakeValue(spv::OpS##_Op, bool_type, a, b); \ } else if (a.stype.type.is_uint()) { \ return MakeValue(spv::OpU##_Op, bool_type, a, b); \ } else { \ - ICHECK(a.stype.type.is_float()); \ + TVM_FFI_ICHECK(a.stype.type.is_float()); \ return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ } \ } @@ -843,13 +844,13 @@ DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual); #define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ Value IRBuilder::_OpName(Value a, Value b) { \ - ICHECK_EQ(a.stype.id, b.stype.id); \ - ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ + TVM_FFI_ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ const auto& bool_type = this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \ if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ return MakeValue(spv::OpI##_Op, bool_type, a, b); \ } else { \ - ICHECK(a.stype.type.is_float()); \ + TVM_FFI_ICHECK(a.stype.type.is_float()); \ return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ } \ } @@ -858,8 +859,8 @@ DEFINE_BUILDER_CMP_UOP(EQ, Equal); DEFINE_BUILDER_CMP_UOP(NE, NotEqual); Value IRBuilder::Select(Value cond, Value a, Value b) { - ICHECK_EQ(a.stype.id, b.stype.id); - ICHECK_EQ(cond.stype.type.element_of(), DataType::Bool()); + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); + TVM_FFI_ICHECK_EQ(cond.stype.type.element_of(), DataType::Bool()); return MakeValue(spv::OpSelect, a.stype, cond, a, b); } diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index 5df779c59547..8be080406506 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -99,7 +99,7 @@ class Instr { * \return reference to idx-th word. */ uint32_t& operator[](uint32_t idx) { - ICHECK_LT(idx, word_count_); + TVM_FFI_ICHECK_LT(idx, word_count_); return (*data_)[begin_ + idx]; } @@ -128,7 +128,7 @@ struct PhiValue : public Value { * \param parent The parent label. */ void SetIncoming(uint32_t index, const Value& value, const Label& parent) { - ICHECK_EQ(this->stype.id, value.stype.id); + TVM_FFI_ICHECK_EQ(this->stype.id, value.stype.id); instr[3 + index * 2] = value.id; instr[3 + index * 2 + 1] = parent.id; } @@ -158,7 +158,7 @@ class InstrBuilder { */ InstrBuilder& Begin(spv::Op op) { // NOLINT(*); // finish previous build - ICHECK_EQ(data_.size(), 0U); + TVM_FFI_ICHECK_EQ(data_.size(), 0U); op_ = op; data_.push_back(0); return *this; diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc index 91b45b85bbd0..5629d4f7f216 100644 --- a/src/target/spirv/spirv_support.cc +++ b/src/target/spirv/spirv_support.cc @@ -33,7 +33,7 @@ namespace codegen { SPIRVSupport::SPIRVSupport(tvm::Target target) { auto device_type = target->GetTargetDeviceType(); - ICHECK(device_type == kDLVulkan || device_type == kDLOpenCL || device_type == kDLWebGPU) + TVM_FFI_ICHECK(device_type == kDLVulkan || device_type == kDLOpenCL || device_type == kDLWebGPU) << "Unsupported device type for SPIRV codegen:" << device_type; if (target->GetAttr("vulkan_api_version")) { diff --git a/src/target/spirv/spirv_utils.cc b/src/target/spirv/spirv_utils.cc index 014ffc4aa191..1724abb52b7f 100644 --- a/src/target/spirv/spirv_utils.cc +++ b/src/target/spirv/spirv_utils.cc @@ -82,10 +82,9 @@ class SPIRVTools { SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | SPV_BINARY_TO_TEXT_OPTION_INDENT, &text, &diagnostic); - ICHECK_EQ(res, SPV_SUCCESS) << " line=" << diagnostic->position.line - << " column=" << diagnostic->position.column - << " index=" << diagnostic->position.index - << " error:" << diagnostic->error; + TVM_FFI_ICHECK_EQ(res, SPV_SUCCESS) + << " line=" << diagnostic->position.line << " column=" << diagnostic->position.column + << " index=" << diagnostic->position.index << " error:" << diagnostic->error; spvDiagnosticDestroy(diagnostic); std::string ret(text->str); @@ -99,8 +98,8 @@ class SPIRVTools { spv_diagnostic diagnostic = nullptr; spv_result_t res = spvValidate(ctx_, &spv_bin, &diagnostic); - ICHECK_EQ(res, SPV_SUCCESS) << " index=" << diagnostic->position.index - << " error:" << diagnostic->error; + TVM_FFI_ICHECK_EQ(res, SPV_SUCCESS) + << " index=" << diagnostic->position.index << " error:" << diagnostic->error; spvDiagnosticDestroy(diagnostic); } @@ -124,13 +123,13 @@ std::pair, std::string> Lo CodeGenSPIRV cg(target); for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance()) << "CodeGenSPIRV: Can only take PrimFunc"; + TVM_FFI_ICHECK(kv.second->IsInstance()) << "CodeGenSPIRV: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + TVM_FFI_ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.has_value()) + TVM_FFI_ICHECK(global_symbol.has_value()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol.value(); @@ -161,7 +160,7 @@ std::pair, std::string> Lo arr.data = reinterpret_cast(shader.data.data()); arr.size = shader.data.size() * sizeof(uint32_t); std::string transformed = (*postproc)(&arr, target).cast(); - ICHECK_EQ(transformed.length() % 4U, 0U); + TVM_FFI_ICHECK_EQ(transformed.length() % 4U, 0U); shader.data.resize(transformed.size() / 4U); std::copy(transformed.begin(), transformed.end(), reinterpret_cast(shader.data.data())); @@ -177,7 +176,7 @@ std::pair, std::string> Lo std::pair, std::string> LowerToSPIRV( IRModule mod, Target target) { - LOG(FATAL) + TVM_FFI_THROW(InternalError) << "LowerToSPIRV is called but SPIRV codegen is not enabled. Please set -DUSE_VULKAN=ON."; return {}; } diff --git a/src/target/tag.cc b/src/target/tag.cc index d8ba94e6c683..c41d07708409 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -76,7 +76,7 @@ ffi::Map TargetTag::ListTags() { Target TargetTag::AddTag(ffi::String name, ffi::Map config, bool override) { TargetTagRegEntry& tag = TargetTagRegEntry::RegisterOrGet(name).set_name(); - ICHECK(override || tag.tag_->config.empty()) + TVM_FFI_ICHECK(override || tag.tag_->config.empty()) << "Tag \"" << name << "\" has been previously defined as: " << tag.tag_->config; tag.set_config(config); return Target(config); diff --git a/src/target/target.cc b/src/target/target.cc index 277ae36bb6c2..91a5854b3934 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -206,8 +206,8 @@ void Target::EnterWithScope() { void Target::ExitWithScope() { TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStoreGet(); - ICHECK(!entry->context_stack.empty()); - ICHECK(entry->context_stack.top().same_as(*this)); + TVM_FFI_ICHECK(!entry->context_stack.empty()); + TVM_FFI_ICHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); } @@ -216,7 +216,7 @@ Target Target::Current(bool allow_not_defined) { if (entry->context_stack.size() > 0) { return entry->context_stack.top(); } - ICHECK(allow_not_defined) + TVM_FFI_ICHECK(allow_not_defined) << "Target context required. Please set it by constructing a TargetContext"; return Target(); @@ -234,7 +234,7 @@ void TargetInternal::ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv) { } else if (auto opt_map = arg.try_cast>()) { *rv = Target(opt_map.value()); } else { - LOG(FATAL) << "TypeError: Cannot create target with type: " << args[0].GetTypeKey(); + TVM_FFI_THROW(TypeError) << "Cannot create target with type: " << args[0].GetTypeKey(); } return; } else if (args.size() == 2) { @@ -243,11 +243,12 @@ void TargetInternal::ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv) { Target host = args[1].cast(); *rv = Target(target, host); } else { - LOG(FATAL) << "ValueError: Invalid type of arguments. Expect 2 Target arguments."; + TVM_FFI_THROW(ValueError) << "Invalid type of arguments. Expect 2 Target arguments."; } return; } - LOG(FATAL) << "ValueError: Invalid number of arguments. Expect 1 or 2, but gets: " << args.size(); + TVM_FFI_THROW(ValueError) << "Invalid number of arguments. Expect 1 or 2, but gets: " + << args.size(); } ObjectPtr TargetInternal::FromString(const ffi::String& tag_or_config_or_target_str) { @@ -294,10 +295,10 @@ ObjectPtr TargetInternal::FromConfig(ffi::Map // Step 0: If "tag" is present without "kind", look up the tag config and merge overrides on top if (!config.count(kKind) && config.count(kTag)) { auto tag_name = config[kTag].try_cast(); - ICHECK(tag_name.has_value()) << "Expect type of field \"tag\" is String, but get type: " - << config[kTag].GetTypeKey(); + TVM_FFI_ICHECK(tag_name.has_value()) + << "Expect type of field \"tag\" is String, but get type: " << config[kTag].GetTypeKey(); auto tag_config = TargetTag::GetConfig(tag_name.value()); - ICHECK(tag_config.has_value()) << "Unknown target tag: " << tag_name.value(); + TVM_FFI_ICHECK(tag_config.has_value()) << "Unknown target tag: " << tag_name.value(); // Start from the tag's base config, then apply user overrides ffi::Map merged = tag_config.value(); for (const auto& kv : config) { @@ -412,9 +413,10 @@ std::unordered_map TargetInternal::QueryDevice(int device api->GetAttr(device, runtime::kExist, &ret); bool device_exists = ret.cast(); if (!device_exists) { - ICHECK(device_exists) << "Requested reading the parameters for " << target->kind->name - << " from device_id " << device_id << ", but device_id " << device_id - << " doesn't exist. Using default target parameters."; + TVM_FFI_ICHECK(device_exists) << "Requested reading the parameters for " << target->kind->name + << " from device_id " << device_id << ", but device_id " + << device_id + << " doesn't exist. Using default target parameters."; return output; } diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index b3f37e6cb653..46e36d101965 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -49,7 +49,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { }) .def("__data_from_json__", [](const ffi::String& name) { auto kind = TargetKind::Get(name); - ICHECK(kind.has_value()) << "Cannot find target kind \'" << name << '\''; + TVM_FFI_ICHECK(kind.has_value()) << "Cannot find target kind \'" << name << '\''; return kind.value(); }); } @@ -150,8 +150,8 @@ void CheckOrSetAttr(ffi::Map* attrs, const ffi::String& n attrs->Set(name, value); } else { auto str = (*iter).second.try_cast(); - ICHECK(str && str.value() == value) << "ValueError: Expects \"" << name << "\" to be \"" - << value << "\", but gets: " << (*iter).second; + TVM_FFI_CHECK(str && str.value() == value, ValueError) + << "Expects \"" << name << "\" to be \"" << value << "\", but gets: " << (*iter).second; } } @@ -167,8 +167,8 @@ ffi::Map UpdateCUDAAttrs(ffi::Map if (target.count("arch")) { // If -arch has been specified, validate the correctness ffi::String archStr = Downcast(target.at("arch")); - ICHECK(support::StartsWith(archStr, "sm_")) - << "ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr; + TVM_FFI_CHECK(support::StartsWith(archStr, "sm_"), ValueError) + << "CUDA target gets an invalid CUDA arch: -arch=" << archStr; } else { // Use the compute version of the first CUDA GPU instead int archInt; @@ -195,8 +195,8 @@ ffi::Map UpdateNVPTXAttrs(ffi::Map if (target.count("mcpu")) { // If -mcpu has been specified, validate the correctness ffi::String mcpu = Downcast(target.at("mcpu")); - ICHECK(support::StartsWith(mcpu, "sm_")) - << "ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu; + TVM_FFI_CHECK(support::StartsWith(mcpu, "sm_"), ValueError) + << "NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu; } else { // Use the compute version of the first CUDA GPU instead int arch; @@ -224,7 +224,8 @@ ffi::Map UpdateROCmAttrs(ffi::Map if (target.count("mcpu")) { ffi::String mcpu = Downcast(target.at("mcpu")); arch = ExtractStringWithPrefix(mcpu, "gfx"); - ICHECK(!arch.empty()) << "ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu; + TVM_FFI_CHECK(!arch.empty(), ValueError) + << "ROCm target gets an invalid GFX version: -mcpu=" << mcpu; } else { ffi::Any val; if (const auto f_get_rocm_arch = tvm::ffi::Function::GetGlobal("tvm_callback_rocm_get_arch")) { diff --git a/src/target/virtual_device.cc b/src/target/virtual_device.cc index 6c797143f020..cd9d8f4ead92 100644 --- a/src/target/virtual_device.cc +++ b/src/target/virtual_device.cc @@ -68,7 +68,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) VirtualDevice::VirtualDevice(int device_type_int, int virtual_device_id, Target target, MemoryScope memory_scope) { - ICHECK(!target.defined() || device_type_int == target->GetTargetDeviceType()) + TVM_FFI_ICHECK(!target.defined() || device_type_int == target->GetTargetDeviceType()) << "target " << target->str() << " has device type " << target->GetTargetDeviceType() << " but virtual device has device type " << device_type_int; auto node = ffi::make_object(); @@ -179,9 +179,9 @@ VirtualDevice VirtualDeviceCache::Make(int device_type, int virtual_device_id, T cache_.emplace(prototype); return prototype; } else { - ICHECK_EQ(prototype->target.defined(), (*itr)->target.defined()); + TVM_FFI_ICHECK_EQ(prototype->target.defined(), (*itr)->target.defined()); if (prototype->target.defined()) { - ICHECK_EQ(prototype->target->host.defined(), (*itr)->target->host.defined()); + TVM_FFI_ICHECK_EQ(prototype->target->host.defined(), (*itr)->target->host.defined()); } return *itr; } diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index fa7424a7cda0..61abb610183a 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -65,27 +65,29 @@ static inline void AssertReduceEqual(const tir::ReduceNode* a, const tir::Reduce StructuralEqual eq; - ICHECK(a->combiner.same_as(b->combiner)) << shared_text << "However, the reduction operation " - << a->combiner << " does not match " << b->combiner; - ICHECK(a->source.same_as(b->source)) + TVM_FFI_ICHECK(a->combiner.same_as(b->combiner)) + << shared_text << "However, the reduction operation " << a->combiner << " does not match " + << b->combiner; + TVM_FFI_ICHECK(a->source.same_as(b->source)) << shared_text << "However, the input " << a->source << " does not match " << b->source; - ICHECK(eq(a->axis, b->axis)) << shared_text << "However, the reduction axis " << a->axis - << " does not match " << b->axis; - ICHECK(eq(a->condition, b->condition)) << shared_text << "However, the predicate " << a->condition - << " does not match " << b->condition; - ICHECK(eq(a->init, b->init)) << shared_text << "However, the initial value " << a->init - << " does not match " << b->init; + TVM_FFI_ICHECK(eq(a->axis, b->axis)) + << shared_text << "However, the reduction axis " << a->axis << " does not match " << b->axis; + TVM_FFI_ICHECK(eq(a->condition, b->condition)) + << shared_text << "However, the predicate " << a->condition << " does not match " + << b->condition; + TVM_FFI_ICHECK(eq(a->init, b->init)) + << shared_text << "However, the initial value " << a->init << " does not match " << b->init; } int ComputeOpNode::num_outputs() const { return body.size(); } DataType ComputeOpNode::output_dtype(size_t idx) const { - ICHECK_LT(idx, num_outputs()); + TVM_FFI_ICHECK_LT(idx, num_outputs()); return body[idx].dtype(); } ffi::Array BaseComputeOpNode::output_shape(size_t idx) const { - ICHECK_LT(idx, num_outputs()); + TVM_FFI_ICHECK_LT(idx, num_outputs()); // for now, all outputs of a BaseComputeOp have the same shape ffi::Array shape; for (const auto& ivar : this->axis) { @@ -210,8 +212,9 @@ class ComputeVerifier final : protected tir::ExprVisitor { for (const PrimExpr e : compute_->body) { // Check for consistency of top level reductions const tir::ReduceNode* reduce = e.as(); - ICHECK((reduce && reduce_) || (!reduce && !reduce_)) << "All ComputeOp should be consistent " - << "with being Reduce operation or not."; + TVM_FFI_ICHECK((reduce && reduce_) || (!reduce && !reduce_)) + << "All ComputeOp should be consistent " + << "with being Reduce operation or not."; if (reduce && reduce_) { AssertReduceEqual(reduce, reduce_); @@ -233,8 +236,8 @@ class ComputeVerifier final : protected tir::ExprVisitor { void VisitExpr_(const tir::ReduceNode* op) final { // Check for non top level reductions - ICHECK(0 == level_) << "Reductions are only allowed at the top level of compute. " - << "Please create another tensor for further composition."; + TVM_FFI_ICHECK(0 == level_) << "Reductions are only allowed at the top level of compute. " + << "Please create another tensor for further composition."; } //@} diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index f1ac7358cd8b..0734310c30d0 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -53,7 +53,7 @@ class ProducerToBufferTransformer : public StmtExprMutator { auto visited_op = Downcast(StmtExprMutator::VisitExpr_(op)); te::Tensor tensor = Downcast(visited_op->producer); auto it = tensor2buffers_.find(tensor); - ICHECK(it != tensor2buffers_.end()) << "IndexError: Cannot find the tensor " << tensor; + TVM_FFI_CHECK(it != tensor2buffers_.end(), IndexError) << "Cannot find the tensor " << tensor; const Buffer& buffer = it->second; return BufferLoad(buffer, visited_op->indices); } @@ -262,8 +262,8 @@ ffi::Array GenerateOutputBuffers(const te::ComputeOp& compute_op, Create // specially handle reduction inline for multiplre reductions. for (size_t k = 1; k < compute_op->body.size(); ++k) { const tir::ReduceNode* reduce_ = compute_op->body[k].as(); - ICHECK(reduce_); - ICHECK(f_reducer_equal(reduce_, reduce)) + TVM_FFI_ICHECK(reduce_); + TVM_FFI_ICHECK(f_reducer_equal(reduce_, reduce)) << "The Reduce inputs of ComputeOp should have the same attribute except value_index, " << "but the first argument has body " << ffi::GetRef(reduce_) << ", while the " << k << "-th argument has body " << ffi::GetRef(reduce); @@ -386,7 +386,7 @@ Stmt GenerateBodyStmt(const ffi::Array& indices, const ffi::ArraySimplify(f_transform_and_remap(reduce->source[i])); lhs.push_back(left); rhs.push_back(right); - ICHECK_EQ(left->dtype, right->dtype); + TVM_FFI_ICHECK_EQ(left->dtype, right->dtype); } ffi::Array temp_vars; @@ -421,7 +421,7 @@ Stmt GenerateBodyStmt(const ffi::Array& indices, const ffi::ArraySimplify(compute_body), indices); } @@ -481,7 +481,7 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in // For each axis, we generate loop and the first block binding at the level it belongs to. // In lower levels, we just create new block var and bind it to the previous level block var. auto axes_levels = GenerateNestedIterLevels(axes, analyzer); - ICHECK(!axes_levels.empty()); + TVM_FFI_ICHECK(!axes_levels.empty()); std::vector scopes; scopes.reserve(axes_levels.size()); std::unordered_set defined_axes; @@ -509,8 +509,8 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in cur_scope.AddBlockIter(axis, new_block_iter, loop_var); defined_axes.insert(axis->var); } else if (defined_axes.count(axis->var)) { - ICHECK_GT(i, 0); - ICHECK(scopes[i - 1].axes_remap.count(axis->var)); + TVM_FFI_ICHECK_GT(i, 0); + TVM_FFI_ICHECK(scopes[i - 1].axes_remap.count(axis->var)); PrimExpr prev_binding = scopes[i - 1].axes_remap.at(axis->var); Var block_var("v_" + axis->var->name_hint, index_type); Range dom = Range::FromMinExtent(prev_binding, make_const(index_type, 1)); @@ -620,18 +620,18 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf // Step 1. Check all inputs are visited before and update var_map. std::unordered_map var_map; std::unordered_map input_buffer_map; - ICHECK_EQ(extern_op->inputs.size(), extern_op->input_placeholders.size()); + TVM_FFI_ICHECK_EQ(extern_op->inputs.size(), extern_op->input_placeholders.size()); for (size_t i = 0; i < extern_op->inputs.size(); ++i) { const Buffer& placeholder = extern_op->input_placeholders[i]; const te::Tensor& input_tensor = extern_op->inputs[i]; auto it = info->tensor2buffers.find(input_tensor); - ICHECK(it != info->tensor2buffers.end()); + TVM_FFI_ICHECK(it != info->tensor2buffers.end()); var_map[placeholder->data.get()] = it->second->data; input_buffer_map[placeholder.get()] = it->second; } // Step 2. Update info with its output tensor and placeholder buffer. - ICHECK_EQ(extern_op->num_outputs(), extern_op->output_placeholders.size()); + TVM_FFI_ICHECK_EQ(extern_op->num_outputs(), extern_op->output_placeholders.size()); for (int i = 0; i < extern_op->num_outputs(); ++i) { const Buffer& placeholder = extern_op->output_placeholders[i]; const te::Tensor& output_tensor = extern_op.output(i); @@ -680,8 +680,8 @@ ffi::Array CollectOrderedOps(const ffi::Array& arg_li for (const te::Operation& op : order) { if (!(op->IsInstance() || op->IsInstance() || op->IsInstance())) - LOG(FATAL) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". " - << "Only te.placeholder and te.compute are allowed for now."; + TVM_FFI_THROW(TypeError) << "Unsupported Operation: " << op->GetTypeKey() << ". " + << "Only te.placeholder and te.compute are allowed for now."; } return order; } @@ -691,7 +691,7 @@ void InitializeBufferBinds(const ffi::Array& ordered_ops, CreateF for (const auto& op : ordered_ops) { // Initialize the tensor2buffer binds map with buffers defined by the te.extern if (const auto* extern_op = op.as()) { - ICHECK_EQ(extern_op->inputs.size(), extern_op->input_placeholders.size()); + TVM_FFI_ICHECK_EQ(extern_op->inputs.size(), extern_op->input_placeholders.size()); for (size_t i = 0; i < extern_op->inputs.size(); ++i) { const te::Tensor& input = extern_op->inputs[i]; const Buffer& buffer = extern_op->input_placeholders[i]; @@ -705,12 +705,13 @@ void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, ffi::Array* root_stmts, arith::Analyzer* analyzer) { if (const auto* placeholder = op.as()) { // Case 1. PlaceholderOp (te.placeholder) - ICHECK_EQ(op->num_outputs(), 1); + TVM_FFI_ICHECK_EQ(op->num_outputs(), 1); const te::Tensor& tensor = op.output(0); // Check op is in op list - ICHECK(info->IsArg(tensor)) << "The operation " << op << " produces tensor " << tensor - << ", but this tensor does not appear as a function argument. " - << "The function accepts arguments " << info->arg_list; + TVM_FFI_ICHECK(info->IsArg(tensor)) + << "The operation " << op << " produces tensor " << tensor + << ", but this tensor does not appear as a function argument. " + << "The function accepts arguments " << info->arg_list; // Declare a buffer for any argument tensors without a pre-existing // buffer declaration recorded in the tensor2buffer binds map if (info->tensor2buffers.count(tensor) == 0) { @@ -725,8 +726,8 @@ void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, // Case 3. ExternOp (te.extern) root_stmts->push_back(GenerateStmtFromExternOp(extern_op.value(), info)); } else { - ICHECK(false) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". " - << "Only te.placeholder and te.compute are allowed for now."; + TVM_FFI_CHECK(false, TypeError) << "Unsupported Operation: " << op->GetTypeKey() << ". " + << "Only te.placeholder and te.compute are allowed for now."; } } @@ -738,7 +739,7 @@ PrimFunc GenerateAndCompletePrimFunc(const ffi::Array& arg_list, Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle())); parameters.push_back(arg); auto it = info->tensor2buffers.find(tensor); - ICHECK(it != info->tensor2buffers.end()); + TVM_FFI_ICHECK(it != info->tensor2buffers.end()); buffer_map.Set(arg, it->second); } PrimFunc func = WithAttrs(PrimFunc(/*params=*/std::move(parameters), @@ -747,7 +748,7 @@ PrimFunc GenerateAndCompletePrimFunc(const ffi::Array& arg_list, /*buffer_map=*/std::move(buffer_map)), {{"global_symbol", ffi::String("main")}, {"tir.noalias", true}}); const auto fcomplete = tvm::ffi::Function::GetGlobal("script.Complete"); - ICHECK(fcomplete.has_value()); + TVM_FFI_ICHECK(fcomplete.has_value()); func = (*fcomplete)(std::move(func), info->root_alloc).cast(); return func; } @@ -805,7 +806,7 @@ PrimFunc GenerateAndCompletePrimFunc(const ffi::Array& arg_tir_var_li Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle())); parameters.push_back(arg); auto it = info->tensor2buffers.find(tensor); - ICHECK(it != info->tensor2buffers.end()); + TVM_FFI_ICHECK(it != info->tensor2buffers.end()); buffer_map.Set(arg, it->second); } else if (auto var = arg.as()) { parameters.push_back(var.value()); @@ -817,7 +818,7 @@ PrimFunc GenerateAndCompletePrimFunc(const ffi::Array& arg_tir_var_li /*buffer_map=*/std::move(buffer_map)), {{"global_symbol", ffi::String("main")}, {"tir.noalias", true}}); const auto fcomplete = tvm::ffi::Function::GetGlobal("script.Complete"); - ICHECK(fcomplete.has_value()); + TVM_FFI_ICHECK(fcomplete.has_value()); func = (*fcomplete)(std::move(func), info->root_alloc).cast(); return func; } diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index def64595412d..b15ae9f67624 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -58,14 +58,14 @@ ExternOp::ExternOp(std::string name, std::string tag, ffi::Mapname = std::move(name); n->tag = std::move(tag); n->attrs = std::move(attrs); - ICHECK_EQ(inputs.size(), input_placeholders.size()); + TVM_FFI_ICHECK_EQ(inputs.size(), input_placeholders.size()); for (size_t i = 0; i < inputs.size(); ++i) { - ICHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype); - ICHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size()); + TVM_FFI_ICHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype); + TVM_FFI_ICHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size()); for (size_t dim = 0; dim < inputs[i]->shape.size(); ++dim) { - ICHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim])); + TVM_FFI_ICHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim])); } - ICHECK_EQ(input_placeholders[i]->strides.size(), 0U); + TVM_FFI_ICHECK_EQ(input_placeholders[i]->strides.size(), 0U); } n->inputs = std::move(inputs); n->input_placeholders = std::move(input_placeholders); diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index 6c7d60841c0f..a063c8304572 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -41,12 +41,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) int PlaceholderOpNode::num_outputs() const { return 1; } DataType PlaceholderOpNode::output_dtype(size_t i) const { - ICHECK_EQ(i, 0U); + TVM_FFI_ICHECK_EQ(i, 0U); return dtype; } ffi::Array PlaceholderOpNode::output_shape(size_t i) const { - ICHECK_EQ(i, 0U); + TVM_FFI_ICHECK_EQ(i, 0U); return shape; } @@ -72,7 +72,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (auto arg_array = shape_arg.as>()) { return arg_array.value(); } else { - LOG(FATAL) << "Variant did not contain either allowed type"; + TVM_FFI_THROW(InternalError) << "Variant did not contain either allowed type"; } }(); return placeholder(shape, dtype, name); diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index fbc65e8a61fb..25d09c931a22 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -43,7 +43,7 @@ int ScanOpNode::num_outputs() const { return static_cast(update.size()); } DataType ScanOpNode::output_dtype(size_t i) const { return update[i]->dtype; } ffi::Array ScanOpNode::output_shape(size_t i) const { - ICHECK_LT(i, state_placeholder.size()); + TVM_FFI_ICHECK_LT(i, state_placeholder.size()); return state_placeholder[i]->shape; } @@ -55,27 +55,27 @@ ScanOp::ScanOp(std::string name, std::string tag, attrs = ffi::Map(); } auto n = ffi::make_object(); - ICHECK_EQ(init.size(), update.size()); - ICHECK_EQ(init.size(), state_placeholder.size()); + TVM_FFI_ICHECK_EQ(init.size(), update.size()); + TVM_FFI_ICHECK_EQ(init.size(), state_placeholder.size()); arith::Analyzer analyzer; auto prove_equal = [&](PrimExpr lhs, PrimExpr rhs) { return is_zero(analyzer.Simplify(lhs - rhs)); }; for (size_t i = 0; i < init.size(); ++i) { - ICHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype); - ICHECK_EQ(init[i]->dtype, update[i]->dtype); - ICHECK(prove_equal(init[i]->shape[0], axis->dom->min)) + TVM_FFI_ICHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype); + TVM_FFI_ICHECK_EQ(init[i]->dtype, update[i]->dtype); + TVM_FFI_ICHECK(prove_equal(init[i]->shape[0], axis->dom->min)) << "init.shape[0] need to match scan_axis.dom.min"; - ICHECK(prove_equal(state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent)) + TVM_FFI_ICHECK(prove_equal(state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent)) << "state_placeholder.shape[0] need to match" << " scan_axis.dom.min + scan_axis.dom.extent"; - ICHECK_EQ(state_placeholder[i].ndim(), init[i].ndim()) + TVM_FFI_ICHECK_EQ(state_placeholder[i].ndim(), init[i].ndim()) << "The dimension of init need to match state_placeholder"; - ICHECK_EQ(update[i].ndim(), state_placeholder[i].ndim()) + TVM_FFI_ICHECK_EQ(update[i].ndim(), state_placeholder[i].ndim()) << "The update.ndim need to be state_placeholder.ndim - 1"; for (size_t k = 0; k < update[i].ndim(); ++k) { - ICHECK(prove_equal(update[i]->shape[k], state_placeholder[i]->shape[k])); + TVM_FFI_ICHECK(prove_equal(update[i]->shape[k], state_placeholder[i]->shape[k])); if (k != 0) { // setup spatial axis std::ostringstream spatial_name; @@ -86,7 +86,7 @@ ScanOp::ScanOp(std::string name, std::string tag, } for (size_t k = 1; k < init[i].ndim(); ++k) { - ICHECK(prove_equal(init[i]->shape[k], state_placeholder[i]->shape[k])); + TVM_FFI_ICHECK(prove_equal(init[i]->shape[k], state_placeholder[i]->shape[k])); } } n->name = std::move(name); diff --git a/src/te/tensor.cc b/src/te/tensor.cc index 8035564b27f4..031f2a0aba09 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -56,7 +56,7 @@ inline PrimExpr Tensor::IndexTensor(ffi::Array indices, ffi::Array shape = (*this)->shape; if (shape.size() != 0) { - ICHECK_EQ(shape.size(), indices.size()) + TVM_FFI_ICHECK_EQ(shape.size(), indices.size()) << "Tensor dimension mismatch in read " << "ndim = " << ndim() << ", indices.size=" << indices.size(); } diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index f498c039ef9e..8d7a13b0b5c7 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -66,7 +66,7 @@ bool HasBufferLoad(PrimExpr expr) { ffi::Optional SubstituteParamValues(const ffi::Array& param_vars, const ffi::Array& param_values, const PrimExpr& expr) { - ICHECK_EQ(param_vars.size(), param_values.size()) + TVM_FFI_ICHECK_EQ(param_vars.size(), param_values.size()) << "Expression was defined as having " << param_vars.size() << " parameters, but received " << param_values.size() << " arguments."; @@ -170,7 +170,7 @@ class BufferConstraintApply : public IRMutatorWithAnalyzer { if (index.dtype().lanes() == 1) { return index; } else { - ICHECK(!lane_var) << "Multiple indices found with non-scalar values"; + TVM_FFI_ICHECK(!lane_var) << "Multiple indices found with non-scalar values"; lane_var = Var("lane", index.dtype().element_of()); num_lanes = IntImm(index.dtype().element_of(), index.dtype().lanes()); return UnwrapVectorExpr(index, lane_var.value()); @@ -265,8 +265,9 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { } else if (side_effect == tir::CallEffectKind::kReadState) { buffer_exprs.push_back(expr); } else { - LOG(FATAL) << "Assumption must be pure or read-only, but contained expression " << expr - << " with side-effect \'" << side_effect << "\'"; + TVM_FFI_THROW(InternalError) + << "Assumption must be pure or read-only, but contained expression " << expr + << " with side-effect \'" << side_effect << "\'"; } } @@ -275,10 +276,11 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { return; } - CHECK_EQ(buffer_exprs.size(), 1) << "T.assume must contain only a single buffer expression"; + TVM_FFI_ICHECK_EQ(buffer_exprs.size(), 1) + << "T.assume must contain only a single buffer expression"; auto* as_equal_node = buffer_exprs[0].as(); - CHECK(as_equal_node || !from_assume_statement) + TVM_FFI_ICHECK(as_equal_node || !from_assume_statement) << "T.assume buffer constraint must be of the form 'buffer[indices] == " "value', but received " << assumption; @@ -300,11 +302,12 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { } else if (!from_assume_statement) { return; } else { - LOG(FATAL) << "T.assume buffer constraint must be of the form 'buffer[indices] == value'"; + TVM_FFI_THROW(InternalError) + << "T.assume buffer constraint must be of the form 'buffer[indices] == value'"; } auto has_side_effect = tir::SideEffect(value) > tir::CallEffectKind::kPure; - CHECK(!has_side_effect || !from_assume_statement) + TVM_FFI_ICHECK(!has_side_effect || !from_assume_statement) << "Buffer value in constraint must be pure expression, but was " << value; if (has_side_effect) { return; @@ -518,8 +521,8 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { */ std::pair MarkControlFlow( size_t from_block, size_t to_block) { - ICHECK_LE(from_block, out_->control_flow_.size()); - ICHECK_LE(to_block, out_->control_flow_.size()); + TVM_FFI_ICHECK_LE(from_block, out_->control_flow_.size()); + TVM_FFI_ICHECK_LE(to_block, out_->control_flow_.size()); auto& forward = out_->control_flow_[from_block].successors.emplace_back( ControlFlowGraph::ControlFlowEdge{to_block, {}, std::nullopt}); @@ -544,7 +547,7 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { new_num_constraints = self->conditions_.size(); } ~InternalConstraintContext() { - ICHECK_EQ(self->conditions_.size(), new_num_constraints) + TVM_FFI_ICHECK_EQ(self->conditions_.size(), new_num_constraints) << "Internal error: Each condition should only be popped once."; self->conditions_.erase(self->conditions_.begin() + old_num_constraints, self->conditions_.end()); @@ -649,7 +652,7 @@ std::pair> ControlFlowGraph::ControlFlowBlock: if (index.dtype().lanes() == 1) { return index; } else { - ICHECK(!lane_var) << "Multiple indices found with non-scalar values"; + TVM_FFI_ICHECK(!lane_var) << "Multiple indices found with non-scalar values"; lane_var = Var("lane", index.dtype().element_of()); num_lanes = IntImm(index.dtype().element_of(), index.dtype().lanes()); return UnwrapVectorExpr(index, lane_var.value()); @@ -673,7 +676,7 @@ std::pair> ControlFlowGraph::ControlFlowBlock: } IntConstraintsTransform transform = [&]() { - ICHECK_EQ(index_variables.size(), index_expressions.size()); + TVM_FFI_ICHECK_EQ(index_variables.size(), index_expressions.size()); ffi::Array relations; @@ -782,7 +785,7 @@ std::pair> ControlFlowGraph::ControlFlowBlock: std::vector> loop_var_expressions; for (const auto& entry : current_block.active_loop_iterators) { auto expr_it = loop_var_to_axis_var.find(entry.loop_var); - ICHECK(expr_it != loop_var_to_axis_var.end()); + TVM_FFI_ICHECK(expr_it != loop_var_to_axis_var.end()); loop_var_expressions.push_back({entry.loop_var, (*expr_it).second}); } @@ -811,7 +814,7 @@ BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph const ffi::Array& indices, BufferTouch::AccessType touch_type, PrimExpr known_value_expr) const { - ICHECK(graph); + TVM_FFI_ICHECK(graph); auto [buffer_touch, free_params] = MakeBufferTouch(buf, graph->GetIndexVariables(buf, indices), indices, touch_type, known_value_expr); for (const auto& pair : free_params) { @@ -831,7 +834,7 @@ ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, int64_t max_simplifica void ControlFlowGraph::RemoveStore(const tir::BufferStore& store) { size_t context_index = [&]() { auto it = control_flow_lookup_.find(store.get()); - ICHECK(it != control_flow_lookup_.end()) + TVM_FFI_ICHECK(it != control_flow_lookup_.end()) << "BufferStore did not occur in the Stmt provided to BufferTouchPattern's constructor"; return it->second; }(); @@ -1405,8 +1408,8 @@ void ControlFlowGraph::ForwardPropagateKnownValues(std::optional flow_fr // Validate internal constraint. This should be true by // construction, as ControlFlowGraphBuilder only builds graphs // that have two or fewer predecessors. - ICHECK_LE(block.predecessors.size(), 2) - << "InternalError: Each block should have at most two predecessors. " + TVM_FFI_CHECK_LE(block.predecessors.size(), 2, InternalError) + << "Each block should have at most two predecessors. " << "Graph constructed in ControlFlowGraphBuilder did not satisfy this constraint."; std::vector states; @@ -1535,7 +1538,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional flow_ if (num_previous_visits >= max_revisits_) { return BufferState(); } - ICHECK_LE(block.successors.size(), 2) + TVM_FFI_ICHECK_LE(block.successors.size(), 2) << "Each block should have at most two successors, but block " << visiting << " breaks this requirement"; @@ -1627,8 +1630,9 @@ bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store, } auto it = control_flow_lookup_.find(context.get()); - ICHECK(it != control_flow_lookup_.end()) << "Context did not occur within analyzed statement:\n" - << context; + TVM_FFI_ICHECK(it != control_flow_lookup_.end()) + << "Context did not occur within analyzed statement:\n" + << context; const auto& context_block = control_flow_[it->second]; auto [store_touch, free_params] = context_block.MakeBufferTouch( @@ -1663,7 +1667,7 @@ PrimExpr ControlFlowGraph::SimplifyInContext(PrimExpr expr, const tir::Stmt& con Analyzer* analyzer) const { size_t context_index = [&]() { auto it = control_flow_lookup_.find(context.get()); - ICHECK(it != control_flow_lookup_.end()) + TVM_FFI_ICHECK(it != control_flow_lookup_.end()) << "Context did not occur in the Stmt provided to BufferTouchPattern's constructor"; return it->second; }(); diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc index 360dbb6e445d..8a009ce4016c 100644 --- a/src/tir/analysis/var_use_def_analysis.cc +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -37,7 +37,7 @@ VarUseDefAnalyzer::VarUseDefAnalyzer(const ffi::Array& defined_vars, bool v void VarUseDefAnalyzer::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); - ICHECK_NE(iv->thread_tag.length(), 0U); + TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); // thread_extent can appear multiple times // use the first appearance as def. if (!use_count_.count(iv->var.get())) { @@ -89,7 +89,7 @@ void VarUseDefAnalyzer::VisitExpr_(const LetNode* op) { auto it = let_binding_.find(op->var.get()); this->VisitExpr(op->value); if (it != let_binding_.end()) { - ICHECK(deep_equal_(it->second->value, op->value)) + TVM_FFI_ICHECK(deep_equal_(it->second->value, op->value)) << "Let cannot bind the same var to two different values"; } else { this->HandleDef(op->var); @@ -130,10 +130,10 @@ void VarUseDefAnalyzer::VisitBuffer(const Buffer& buffer) { void VarUseDefAnalyzer::HandleDef(const Var& var) { auto v = var.get(); - ICHECK(!def_count_.count(v)) << "variable " << v->name_hint - << " has already been defined, the Stmt is not SSA"; - ICHECK(!use_count_.count(v)) << "variable " << v->name_hint - << " has been used before definition!"; + TVM_FFI_ICHECK(!def_count_.count(v)) + << "variable " << v->name_hint << " has already been defined, the Stmt is not SSA"; + TVM_FFI_ICHECK(!use_count_.count(v)) + << "variable " << v->name_hint << " has been used before definition!"; use_count_[v] = 0; def_count_[v] = 1; } @@ -153,9 +153,9 @@ void VarUseDefAnalyzer::HandleUse(const Var& var) { void VarUseDefAnalyzer::HandleDef(const Buffer& buf) { auto ptr = buf.get(); - ICHECK(!buffer_def_count_.count(ptr)) + TVM_FFI_ICHECK(!buffer_def_count_.count(ptr)) << "buffer " << ptr->name << " has already been defined, the Stmt is not SSA"; - ICHECK(!buffer_use_count_.count(ptr)) + TVM_FFI_ICHECK(!buffer_use_count_.count(ptr)) << "buffer " << ptr->name << " has been used before definition!"; buffer_use_count_[ptr] = 0; buffer_def_count_[ptr] = 1; @@ -205,7 +205,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (auto opt_expr = args[0].as()) { *rv = UndefinedVars(opt_expr.value(), args[1].cast>()); } else { - LOG(FATAL) << "either UndefinedVars(stmt, args) or UndefinedVars(expr, args) is expected"; + TVM_FFI_THROW(InternalError) + << "either UndefinedVars(stmt, args) or UndefinedVars(expr, args) is expected"; } }); } diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index a82de34716c8..10642f1c703d 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -169,7 +169,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { /// Interface of VerifyMemory pass std::vector VerifyMemory_(const PrimFunc& func) { auto target = func->GetAttr(tvm::attr::kTarget); - ICHECK(target.defined()) << "VerifyMemory: Require the target attribute"; + TVM_FFI_ICHECK(target.defined()) << "VerifyMemory: Require the target attribute"; VLOG(1) << "verifying memory for target '" << target.value()->str() << "' for primitive:" << std::endl @@ -204,9 +204,9 @@ Pass VerifyMemory() { for (auto& err : errs) { s << " " << err << "\n"; } - LOG(FATAL) << "RuntimeError: Memory verification failed with the following errors:\n" - << s.str() << " Did you forget to bind?\n" - << func.value(); + TVM_FFI_THROW(RuntimeError) << "Memory verification failed with the following errors:\n" + << s.str() << " Did you forget to bind?\n" + << func.value(); } } } diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc index eafe28bd63a9..31c65917a3dc 100644 --- a/src/tir/analysis/verify_ssa.cc +++ b/src/tir/analysis/verify_ssa.cc @@ -151,7 +151,8 @@ Pass VerifySSA() { auto pass_func = [=](IRModule mod, PassContext ctx) { for (auto kv : mod->functions) { if (auto func = kv.second.as()) { - ICHECK(VerifySSA(func.value())) << "RuntimeError: IR is not in SSA form" << func.value(); + TVM_FFI_CHECK(VerifySSA(func.value()), RuntimeError) + << "IR is not in SSA form" << func.value(); } } return mod; diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index e50b60d55c7e..00d0ebbbcd18 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -60,18 +60,18 @@ class Verifier : protected TIRVisitorWithPath { * * Each verifier can either return a boolean, or assert on failure. * To avoid needing to duplicate this logic at every step, the - * Verify() method can be used. Similar to `LOG(FATAL)` or + * Verify() method can be used. Similar to `TVM_FFI_THROW(InternalError)` or * `LOG(DEBUG)`, it returns an object that can accept streamed * context information. * * If the error should be raised, then the context is collected - * identically to `LOG(FATAL)`. If a boolean is returned, or if the + * identically to `TVM_FFI_THROW(InternalError)`. If a boolean is returned, or if the * condition passes, then the streamed context is discarded. * * Usage: * * Verify(value == expected_value) - * << "ValueError: " << value + * << value * << " was not the expected value of " << expected_value; */ class VerifyStream { @@ -100,7 +100,7 @@ class Verifier : protected TIRVisitorWithPath { ~VerifyStream() noexcept(false) { if (log_.has_value()) { - LOG(FATAL) << log_->str(); + TVM_FFI_THROW(ValueError) << log_->str(); } } @@ -153,24 +153,25 @@ class BlockVarAccessVerifier : public StmtExprVisitor { has_error_ = true; if (assert_mode_) { if (it->second == 0) { - LOG(FATAL) << "Well-formedness check failed: " - << "Loop iterator var " << op->name_hint - << " is defined outside of any block, " - << "but is used inside the non-opaque current block \"" - << block_stack_.back()->name_hint << "\"."; + TVM_FFI_THROW(InternalError) + << "Well-formedness check failed: " + << "Loop iterator var " << op->name_hint << " is defined outside of any block, " + << "but is used inside the non-opaque current block \"" + << block_stack_.back()->name_hint << "\"."; } else { - LOG(FATAL) << "Well-formedness check failed: " - << "Loop iterator var " << op->name_hint << " is defined in block \"" - << block_stack_[it->second - 1]->name_hint << "\", " - << "but is used inside the non-opaque current block \"" - << block_stack_.back()->name_hint << "\"."; + TVM_FFI_THROW(InternalError) + << "Well-formedness check failed: " + << "Loop iterator var " << op->name_hint << " is defined in block \"" + << block_stack_[it->second - 1]->name_hint << "\", " + << "but is used inside the non-opaque current block \"" + << block_stack_.back()->name_hint << "\"."; } } } } void VisitStmt_(const ForNode* op) final { - ICHECK(loop_vars_.find(op->loop_var.get()) == loop_vars_.end()); + TVM_FFI_ICHECK(loop_vars_.find(op->loop_var.get()) == loop_vars_.end()); loop_vars_[op->loop_var.get()] = block_stack_.size(); StmtExprVisitor::VisitStmt_(op); loop_vars_.erase(op->loop_var.get()); @@ -249,8 +250,7 @@ class UndefinedVarVerifier : public Verifier { { auto it = currently_defined_.find(var); auto verify = Verify(it == currently_defined_.end() || redefine_is_allowed); - verify << "ValueError: " - << "TIR is ill-formed, " + verify << "TIR is ill-formed, " << "due to multiple nested definitions of variable " << var << "."; if (it != currently_defined_.end()) { verify << " It was first defined at " << it->second << ", and was re-defined at " << path; @@ -260,8 +260,7 @@ class UndefinedVarVerifier : public Verifier { { auto it = previously_defined_.find(var); auto verify = Verify(it == previously_defined_.end() || redefine_is_allowed); - verify << "ValueError: " - << "TIR is ill-formed, " + verify << "TIR is ill-formed, " << "due to multiple definitions of variable " << var << "."; if (it != previously_defined_.end()) { verify << " It was first defined at " << it->second << ", and was later re-defined at " @@ -284,8 +283,7 @@ class UndefinedVarVerifier : public Verifier { auto active_def = currently_defined_.find(var); auto verify = Verify(active_def != currently_defined_.end()); - verify << "ValueError: " - << "Invalid use of undefined variable " << var << " at " << path << "."; + verify << "Invalid use of undefined variable " << var << " at " << path << "."; // Check if there was a previous definition, and append the // location to the error message if there was. This is to aid in @@ -332,7 +330,6 @@ class SingleEnvThreadVerifier : public Verifier { if (auto it = env_thread_vars_.find(iter_var->thread_tag); it != env_thread_vars_.end()) { const auto& [prev_var, prev_path] = it->second; Verify(prev_var.same_as(iter_var->var)) - << "ValueError: " << "PrimFunc uses multiple distinct TIR variables " << " for the environment thread \"" << iter_var->thread_tag << "\". " << "While multiple tir::AttrStmt may define the same environment thread, " @@ -378,17 +375,18 @@ bool VerifyWellFormed(const IRModule& mod, bool assert_mode) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.analysis.VerifyWellFormed", [](const ObjectRef& obj, - bool assert_mode) { - if (auto opt = obj.as()) { - return VerifyWellFormed(opt.value(), assert_mode); - } else if (auto opt = obj.as()) { - return VerifyWellFormed(opt.value(), assert_mode); - } else { - LOG(FATAL) << "Expected VerifyWellFormed argument to be a PrimFunc or IRModule, but found " - << obj->GetTypeKey(); - } - }); + refl::GlobalDef().def( + "tir.analysis.VerifyWellFormed", [](const ObjectRef& obj, bool assert_mode) { + if (auto opt = obj.as()) { + return VerifyWellFormed(opt.value(), assert_mode); + } else if (auto opt = obj.as()) { + return VerifyWellFormed(opt.value(), assert_mode); + } else { + TVM_FFI_THROW(InternalError) + << "Expected VerifyWellFormed argument to be a PrimFunc or IRModule, but found " + << obj->GetTypeKey(); + } + }); } } // namespace tir diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index f8f237013245..b569a1cb07ae 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -145,7 +145,7 @@ inline std::pair MergeMulModInner(arith::Analyzer* analyzer, no_opt_sum = no_opt_sum.get() ? no_opt_sum + inner_add_ptr->a : inner_add_ptr->a; search_ptr = &(inner_add_ptr->b); } else { - LOG(FATAL) << "Unexpected search result!"; + TVM_FFI_THROW(InternalError) << "Unexpected search result!"; break; } } @@ -260,13 +260,13 @@ ffi::Array Buffer::OffsetOf(ffi::Array input_indices) const // original data ignoring number of lanes. // We also perform optimization to simplify the indexing expression. ffi::Array BufferNode::ElemOffset(ffi::Array input_indices) const { - ICHECK_EQ(shape.size(), input_indices.size()) + TVM_FFI_ICHECK_EQ(shape.size(), input_indices.size()) << "Buffer " << this->name << " is " << shape.size() << "-dimensional, cannot be indexed with the " << input_indices.size() << "-dimensional indices provided."; if (strides.size()) { - ICHECK_EQ(this->strides.size(), input_indices.size()) + TVM_FFI_ICHECK_EQ(this->strides.size(), input_indices.size()) << "If strides are defined, " << "the index's dimensionality must match the dimensionality of the index given."; } @@ -280,7 +280,7 @@ ffi::Array BufferNode::ElemOffset(ffi::Array input_indices) } if (elem_offsets.size()) { - ICHECK_EQ(elem_offsets.size(), axis_separators.size() + 1) + TVM_FFI_ICHECK_EQ(elem_offsets.size(), axis_separators.size() + 1) << "If element offsets are defined, " << "there must be one element offset for each output index."; } @@ -347,20 +347,17 @@ static void ValidateAxisSeparators(const ffi::Array& axis_separators, si for (size_t i = 0; (i + 1) < axis_separators.size(); i++) { auto sep = axis_separators[i]->value; auto next_sep = axis_separators[i + 1]->value; - CHECK_LE(sep, next_sep) << "ValueError: " - << "Axis separators must be in increasing order, " - << "but axis_separators[" << i << "] = " << sep - << " is greater than or equal to axis_separators[" << (i + 1) - << "] = " << next_sep << "."; + TVM_FFI_CHECK_LE(sep, next_sep, ValueError) + << "Axis separators must be in increasing order, " + << "but axis_separators[" << i << "] = " << sep + << " is greater than or equal to axis_separators[" << (i + 1) << "] = " << next_sep << "."; } if (axis_separators.size()) { auto first_sep = axis_separators[0]->value; - CHECK_GE(first_sep, 0) << "ValueError: " - << "All axis separators must be non-negative. " - << "However, the axis_separators[0] = " << first_sep; + TVM_FFI_CHECK_GE(first_sep, 0, ValueError) << "All axis separators must be non-negative. " + << "However, the axis_separators[0] = " << first_sep; auto last_sep = axis_separators[axis_separators.size() - 1]->value; - CHECK_LE(last_sep, buffer_dim) - << "ValueError: " + TVM_FFI_CHECK_LE(last_sep, buffer_dim, ValueError) << "All axis separators must be within the range " << "0 <= sep <= buffer_dim. " << "However, the last axis_separators[" << (axis_separators.size() - 1) @@ -378,7 +375,7 @@ Buffer Buffer::GetFlattenedBuffer() const { // If strides are defined, then the extent of each flattened // buffer is the stride*size for the first input axis used for // each output axis. - ICHECK_EQ(self->shape.size(), self->strides.size()); + TVM_FFI_ICHECK_EQ(self->shape.size(), self->strides.size()); output_shape.push_back(self->strides[0] * self->shape[0]); for (const auto& sep : self->axis_separators) { output_shape.push_back(self->strides[sep->value] * self->shape[sep->value]); @@ -423,9 +420,9 @@ PrimExpr Buffer::vload(ffi::Array begin, DataType value_dtype, ffi::Optional predicate) const { // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); - ICHECK(n != nullptr); - ICHECK(value_dtype.element_of() == n->dtype.element_of() && - value_dtype.get_lanes_or_vscale_factor() % n->dtype.lanes() == 0) + TVM_FFI_ICHECK(n != nullptr); + TVM_FFI_ICHECK(value_dtype.element_of() == n->dtype.element_of() && + value_dtype.get_lanes_or_vscale_factor() % n->dtype.lanes() == 0) << "Cannot load " << value_dtype << " from buffer of " << n->dtype; ffi::Array indices = begin; @@ -443,10 +440,10 @@ Stmt Buffer::vstore(ffi::Array begin, PrimExpr value, ffi::Optional predicate) const { // specially handle bool, stored as DataType::Int(8) const BufferNode* n = operator->(); - ICHECK(n != nullptr); + TVM_FFI_ICHECK(n != nullptr); DataType value_dtype = value.dtype(); - ICHECK(value_dtype.element_of() == n->dtype.element_of() && - value_dtype.get_lanes_or_vscale_factor() % n->dtype.lanes() == 0) + TVM_FFI_ICHECK(value_dtype.element_of() == n->dtype.element_of() && + value_dtype.get_lanes_or_vscale_factor() % n->dtype.lanes() == 0) << "Cannot store " << value_dtype << " to buffer of " << n->dtype; ffi::Array indices = begin; @@ -462,7 +459,7 @@ Stmt Buffer::vstore(ffi::Array begin, PrimExpr value, ffi::String Buffer::scope() const { const auto* ptr_type = (*this)->data->type_annotation.as(); - ICHECK(ptr_type) << "Buffer variable is not of pointer type"; + TVM_FFI_ICHECK(ptr_type) << "Buffer variable is not of pointer type"; if (ptr_type->storage_scope.empty()) { return "global"; } @@ -474,7 +471,7 @@ Buffer Buffer::MakeStrideView() const { if ((*this)->shape.size() == 0) return *this; std::vector temp; const BufferNode* self = operator->(); - ICHECK(self != nullptr); + TVM_FFI_ICHECK(self != nullptr); auto n = ffi::make_object(*self); PrimExpr acc = make_const(n->DefaultIndexType(), 1); for (size_t i = n->shape.size(); i != 0; --i) { @@ -489,7 +486,7 @@ Buffer Buffer::MakeStrideView() const { Buffer Buffer::MakeSlice(ffi::Array begins, ffi::Array extents) const { const BufferNode* n = operator->(); - ICHECK(n != nullptr); + TVM_FFI_ICHECK(n != nullptr); arith::Analyzer ana; begins = SimplifyArray(&ana, begins); ffi::Array elem_offset = @@ -532,7 +529,7 @@ Buffer Buffer::MakeSlice(ffi::Array begins, ffi::Array exten PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, PrimExpr offset, ffi::Optional input_extent) const { const BufferNode* self = operator->(); - ICHECK(self != nullptr); + TVM_FFI_ICHECK(self != nullptr); PrimExpr e_dtype; PrimExpr extent; if (self->shape.size() == 0) { @@ -579,11 +576,11 @@ Buffer::Buffer(Var data, DataType dtype, ffi::Array shape, ffi::Array< // TODO(Lunderberg): Use an explicit pointer cast for the data // pointer. Should be done alongside extensions to StmtExprMutator // to more easily handle buffer/buffer_var updates. - ICHECK(data->type_annotation.defined()) + TVM_FFI_ICHECK(data->type_annotation.defined()) << "Variable " << data->name_hint << " is missing a type annotation."; - ICHECK(data->type_annotation.as()) + TVM_FFI_ICHECK(data->type_annotation.as()) << "Variable " << data->name_hint << " is not a pointer."; - ICHECK(data->type_annotation.as()->element_type.as()) + TVM_FFI_ICHECK(data->type_annotation.as()->element_type.as()) << "Variable " << data->name_hint << " does not point to a primitive."; ValidateAxisSeparators(axis_separators, shape.size()); @@ -650,7 +647,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef() .def_packed("tir.Buffer", [](ffi::PackedArgs args, ffi::Any* ret) { - ICHECK_EQ(args.size(), 11); + TVM_FFI_ICHECK_EQ(args.size(), 11); auto buffer_type = args[8].cast(); BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; auto data = args[0].cast(); diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index f7f6f7256a85..7e8b2ed2a694 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -42,7 +42,7 @@ namespace tir { Stmt DataTypeLegalizer::VisitStmt_(const ForNode* op) { Stmt s = StmtExprMutator::VisitStmt_(op); op = s.as(); - ICHECK(op != nullptr) << "Expected type to be ForNode, but get " << s->GetTypeKey(); + TVM_FFI_ICHECK(op != nullptr) << "Expected type to be ForNode, but get " << s->GetTypeKey(); PrimExpr e = VisitExpr(op->loop_var); Var var = Downcast(e); auto n = CopyOnWrite(op); @@ -97,18 +97,18 @@ Stmt DataTypeLegalizer::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { Stmt s = StmtExprMutator::VisitStmt_(op); op = s.as(); - ICHECK(op != nullptr) << "Expected type to be AttrStmtNode" - << ", but get " << s->GetTypeKey(); + TVM_FFI_ICHECK(op != nullptr) << "Expected type to be AttrStmtNode" + << ", but get " << s->GetTypeKey(); const IterVarNode* iv = op->node.as(); - ICHECK(iv != nullptr) << "Expected type to be IterVarNode" - << ", but get " << op->node.GetTypeKey(); + TVM_FFI_ICHECK(iv != nullptr) << "Expected type to be IterVarNode" + << ", but get " << op->node.GetTypeKey(); PrimExpr e = VisitExpr(iv->var); Var var = Downcast(e); if (ivmap_.find(iv) == ivmap_.end()) { Range dom = iv->dom; if (dom.defined()) { PrimExpr extend = dom->extent; - ICHECK(extend.dtype().is_int() && var.dtype().is_int()); + TVM_FFI_ICHECK(extend.dtype().is_int() && var.dtype().is_int()); if (var.dtype().bits() != extend.dtype().bits()) { DataType dtype = var.dtype(); dom = Range(cast(dtype, dom->min), cast(dtype, extend), dom->span); @@ -186,7 +186,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const RampNode* op) { if (base.same_as(op->base) && stride.same_as(op->stride) && base.dtype() == stride.dtype()) { return ffi::GetRef(op); } else { - ICHECK(base.dtype().is_int() && stride.dtype().is_int()); + TVM_FFI_ICHECK(base.dtype().is_int() && stride.dtype().is_int()); int bits = std::max(base.dtype().bits(), stride.dtype().bits()); DataType dtype = base.dtype().with_bits(bits); if (base.dtype() != dtype) base = cast(dtype, base); @@ -233,8 +233,8 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) { PrimExpr e = StmtExprMutator::VisitExpr_(op); op = e.as(); static const Op& builtin_pow_ = Op::Get("tir.pow"); - ICHECK(op != nullptr) << "Expected type to be CallNode" - << ", but get " << e->GetTypeKey(); + TVM_FFI_ICHECK(op != nullptr) << "Expected type to be CallNode" + << ", but get " << e->GetTypeKey(); if (op->op.same_as(builtin::shift_right())) { return op->args[0] >> op->args[1]; } else if (op->op.same_as(builtin::shift_left())) { @@ -252,12 +252,12 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) { } else if (op->op.same_as(Op::Get("tir.clz"))) { DataType before_dtype = before->args[0]->dtype; DataType after_dtype = op->args[0]->dtype; - CHECK((before_dtype.is_int() || before_dtype.is_uint()) && - (before_dtype.bits() == 32 || before_dtype.bits() == 64)) + TVM_FFI_ICHECK((before_dtype.is_int() || before_dtype.is_uint()) && + (before_dtype.bits() == 32 || before_dtype.bits() == 64)) << "clz only supports 32 or 64 bit integer types, but get type before legalizing: " << before_dtype; - CHECK((after_dtype.is_int() || after_dtype.is_uint()) && - (after_dtype.bits() == 32 || after_dtype.bits() == 64)) + TVM_FFI_ICHECK((after_dtype.is_int() || after_dtype.is_uint()) && + (after_dtype.bits() == 32 || after_dtype.bits() == 64)) << "clz only supports 32 or 64 bit integer types, but get type after legalizing: " << after_dtype; return e - after_dtype.bits() + before_dtype.bits(); @@ -587,7 +587,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const LetStmtNode* op) { PrimExpr value = VisitExpr(op->value); Var var = var_remap_[let_stmt->var.get()]; is_enabled_ = is_enabled; - ICHECK(value.dtype() == var.dtype()); + TVM_FFI_ICHECK(value.dtype() == var.dtype()); // No need to re-visit body return LetStmt(var, value, let_stmt->body, let_stmt->span); } @@ -680,7 +680,7 @@ bool IndexDataTypeNormalizer::CanRewriteDType(DataType dtype) const { PrimExpr IndexDataTypeNormalizer::VisitExpr_(const IntImmNode* op) { if (is_enabled_ && CanRewriteDType(op->dtype)) { - ICHECK_LE(op->value, Downcast(max_value(target_data_type_))->value); + TVM_FFI_ICHECK_LE(op->value, Downcast(max_value(target_data_type_))->value); return cast(target_data_type_, ffi::GetRef(op)); } return ffi::GetRef(op); diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 5eee4ffd8bd5..77fccc040f43 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -86,36 +86,36 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](ffi::Variant> expr) { return expr; }); } -#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ - Name::Name(PrimExpr a, PrimExpr b, Span span) { \ - using T = Name::ContainerType; \ - ICHECK(a.defined()) << "ValueError: a is undefined\n"; \ - ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ - CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \ - << b.dtype() << "\n"; \ - ObjectPtr node = ffi::make_object(); \ - node->dtype = a.dtype(); \ - node->a = std::move(a); \ - node->b = std::move(b); \ - node->span = std::move(span); \ - data_ = std::move(node); \ +#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ + Name::Name(PrimExpr a, PrimExpr b, Span span) { \ + using T = Name::ContainerType; \ + TVM_FFI_CHECK(a.defined(), ValueError) << "a is undefined\n"; \ + TVM_FFI_CHECK(b.defined(), ValueError) << "b is undefined\n"; \ + TVM_FFI_CHECK(a.dtype() == b.dtype(), TypeError) \ + << "mismatched types. " << a.dtype() << " vs. " << b.dtype() << "\n"; \ + ObjectPtr node = ffi::make_object(); \ + node->dtype = a.dtype(); \ + node->a = std::move(a); \ + node->b = std::move(b); \ + node->span = std::move(span); \ + data_ = std::move(node); \ } -#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \ - Name::Name(PrimExpr a, PrimExpr b, Span span) { \ - using T = Name::ContainerType; \ - ICHECK(a.defined()) << "ValueError: a is undefined\n"; \ - ICHECK(b.defined()) << "ValueError: b is undefined\n"; \ - CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \ - << b.dtype() << "\n"; \ - ObjectPtr node = ffi::make_object(); \ - DataType a_dtype = a.dtype(); \ - node->dtype = \ - DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), a_dtype.is_scalable_vector()); \ - node->a = std::move(a); \ - node->b = std::move(b); \ - node->span = std::move(span); \ - data_ = std::move(node); \ +#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \ + Name::Name(PrimExpr a, PrimExpr b, Span span) { \ + using T = Name::ContainerType; \ + TVM_FFI_CHECK(a.defined(), ValueError) << "a is undefined\n"; \ + TVM_FFI_CHECK(b.defined(), ValueError) << "b is undefined\n"; \ + TVM_FFI_CHECK(a.dtype() == b.dtype(), TypeError) \ + << "mismatched types. " << a.dtype() << " vs. " << b.dtype() << "\n"; \ + ObjectPtr node = ffi::make_object(); \ + DataType a_dtype = a.dtype(); \ + node->dtype = \ + DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), a_dtype.is_scalable_vector()); \ + node->a = std::move(a); \ + node->b = std::move(b); \ + node->span = std::move(span); \ + data_ = std::move(node); \ } // Var @@ -206,11 +206,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { IterVar::IterVar(Range dom, Var var, IterVarType t, ffi::String thread_tag, Span span) { ObjectPtr n = ffi::make_object(); if (dom.defined() && dom->extent.defined()) { - CHECK(dom->extent.dtype().is_int()) + TVM_FFI_ICHECK(dom->extent.dtype().is_int()) << "The dtype of the domain of an IterVar must be an integer type. However, the domain's " "dtype is " << dom->extent.dtype(); - CHECK_EQ(dom->extent.dtype(), var.dtype()) + TVM_FFI_ICHECK_EQ(dom->extent.dtype(), var.dtype()) << "The dtype of the extent of an IterVar (" << dom->extent.dtype() << ") must match its associated Var's dtype (" << var.dtype() << ")"; } @@ -247,9 +247,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Cast Cast::Cast(DataType t, PrimExpr value, Span span) { - ICHECK(value.defined()); - ICHECK_EQ(t.get_lanes_or_vscale_factor(), value.dtype().get_lanes_or_vscale_factor()); - ICHECK(t.is_scalable_vector() == value.dtype().is_scalable_vector()); + TVM_FFI_ICHECK(value.defined()); + TVM_FFI_ICHECK_EQ(t.get_lanes_or_vscale_factor(), value.dtype().get_lanes_or_vscale_factor()); + TVM_FFI_ICHECK(t.is_scalable_vector() == value.dtype().is_scalable_vector()); ObjectPtr node = ffi::make_object(); node->dtype = t; node->value = std::move(value); @@ -395,11 +395,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { // And And::And(PrimExpr a, PrimExpr b, Span span) { - ICHECK(a.defined()) << "ValueError: a is undefined"; - ICHECK(b.defined()) << "ValueError: b is undefined"; - ICHECK(a.dtype().is_bool()); - ICHECK(b.dtype().is_bool()); - ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; + TVM_FFI_CHECK(a.defined(), ValueError) << "a is undefined"; + TVM_FFI_CHECK(b.defined(), ValueError) << "b is undefined"; + TVM_FFI_ICHECK(a.dtype().is_bool()); + TVM_FFI_ICHECK(b.dtype().is_bool()); + TVM_FFI_CHECK(a.dtype() == b.dtype(), TypeError) << "mismatched types"; ObjectPtr node = ffi::make_object(); node->dtype = @@ -418,11 +418,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Or Or::Or(PrimExpr a, PrimExpr b, Span span) { - ICHECK(a.defined()) << "ValueError: a is undefined"; - ICHECK(b.defined()) << "ValueError: b is undefined"; - ICHECK(a.dtype().is_bool()); - ICHECK(b.dtype().is_bool()); - ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; + TVM_FFI_CHECK(a.defined(), ValueError) << "a is undefined"; + TVM_FFI_CHECK(b.defined(), ValueError) << "b is undefined"; + TVM_FFI_ICHECK(a.dtype().is_bool()); + TVM_FFI_ICHECK(b.dtype().is_bool()); + TVM_FFI_CHECK(a.dtype() == b.dtype(), TypeError) << "mismatched types"; ObjectPtr node = ffi::make_object(); node->dtype = @@ -440,8 +440,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Not Not::Not(PrimExpr a, Span span) { - ICHECK(a.defined()) << "ValueError: a is undefined"; - ICHECK(a.dtype().is_bool()); + TVM_FFI_CHECK(a.defined(), ValueError) << "a is undefined"; + TVM_FFI_ICHECK(a.dtype().is_bool()); ObjectPtr node = ffi::make_object(); DataType a_dtype = a.dtype(); @@ -458,15 +458,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Select Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) { - ICHECK(condition.defined()) << "ValueError: condition is undefined"; - ICHECK(true_value.defined()) << "ValueError: true_value is undefined"; - ICHECK(false_value.defined()) << "ValueError: true_value is undefined"; - ICHECK(condition.dtype().is_bool()); - ICHECK(condition.dtype().get_lanes_or_vscale_factor() == - true_value.dtype().get_lanes_or_vscale_factor() || - condition.dtype().is_scalar()); - ICHECK(false_value.dtype() == true_value.dtype()) - << "TypeError: mismatched types. " + TVM_FFI_CHECK(condition.defined(), ValueError) << "condition is undefined"; + TVM_FFI_CHECK(true_value.defined(), ValueError) << "true_value is undefined"; + TVM_FFI_CHECK(false_value.defined(), ValueError) << "true_value is undefined"; + TVM_FFI_ICHECK(condition.dtype().is_bool()); + TVM_FFI_ICHECK(condition.dtype().get_lanes_or_vscale_factor() == + true_value.dtype().get_lanes_or_vscale_factor() || + condition.dtype().is_scalar()); + TVM_FFI_CHECK(false_value.dtype() == true_value.dtype(), TypeError) + << "mismatched types. " << "False type: " << false_value.dtype() << "; True type: " << true_value.dtype(); ObjectPtr node = ffi::make_object(); @@ -488,10 +488,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Ramp Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { - ICHECK(base.defined()); - ICHECK(stride.defined()); - ICHECK(base.dtype().is_scalar()); - ICHECK(stride.dtype().is_scalar()); + TVM_FFI_ICHECK(base.defined()); + TVM_FFI_ICHECK(stride.defined()); + TVM_FFI_ICHECK(base.dtype().is_scalar()); + TVM_FFI_ICHECK(stride.dtype().is_scalar()); if (stride.dtype() != base.dtype()) { stride = cast(base.dtype(), stride); } @@ -500,13 +500,13 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { auto* lanes_as_int = lanes.as(); if (lanes_as_int) { int lanes = static_cast(lanes_as_int->value); - ICHECK_GT(lanes, 1); + TVM_FFI_ICHECK_GT(lanes, 1); node->dtype = base.dtype().with_lanes(lanes); // Stick to int32 lanes for fixed length vectors node->lanes = lanes; } else { /* scalable vector */ std::optional vscale_factor = arith::ExtractVscaleFactor(lanes); - ICHECK(vscale_factor) << "Invalid expression for scalable lanes " << lanes; + TVM_FFI_ICHECK(vscale_factor) << "Invalid expression for scalable lanes " << lanes; node->dtype = base.dtype().with_scalable_vscale_factor(vscale_factor.value()); lanes = Mul(Call(DataType::Int(32), tir::builtin::vscale(), {}), vscale_factor.value()); @@ -527,20 +527,20 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Broadcast Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span span) { - ICHECK(value.defined()); - ICHECK(value.dtype().is_scalar()); + TVM_FFI_ICHECK(value.defined()); + TVM_FFI_ICHECK(value.dtype().is_scalar()); ObjectPtr node = ffi::make_object(); auto* lanes_int = lanes.as(); if (lanes_int) { int lanes = static_cast(lanes_int->value); - ICHECK_GT(lanes, 1); + TVM_FFI_ICHECK_GT(lanes, 1); node->dtype = value.dtype().with_lanes(lanes); // Stick to int32 lanes for fixed length vectors node->lanes = lanes; } else { /* scalable vector */ std::optional vscale_factor = arith::ExtractVscaleFactor(lanes); - ICHECK(vscale_factor) << "Invalid expression for scalable lanes " << lanes; + TVM_FFI_ICHECK(vscale_factor) << "Invalid expression for scalable lanes " << lanes; node->dtype = value.dtype().with_scalable_vscale_factor(vscale_factor.value()); lanes = Mul(Call(DataType::Int(32), tir::builtin::vscale(), {}), vscale_factor.value()); @@ -560,9 +560,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Let Let::Let(Var var, PrimExpr value, PrimExpr body, Span span) { - ICHECK(value.defined()); - ICHECK(body.defined()); - ICHECK_EQ(value.dtype(), var.dtype()); + TVM_FFI_ICHECK(value.defined()); + TVM_FFI_ICHECK(body.defined()); + TVM_FFI_ICHECK_EQ(value.dtype(), var.dtype()); ObjectPtr node = ffi::make_object(); node->dtype = body.dtype(); @@ -583,7 +583,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Call Call::Call(DataType dtype, RelaxExpr op, ffi::Array args, Span span) { for (size_t i = 0; i < args.size(); ++i) { - ICHECK(args[i].defined()) << "arg " << i << " is not defined()"; + TVM_FFI_ICHECK(args[i].defined()) << "arg " << i << " is not defined()"; } ObjectPtr node = ffi::make_object(); @@ -617,8 +617,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (r->extent.as()) { indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 1), r->extent)); } else { - LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " - << ffi::GetRef(br); + TVM_FFI_THROW(ValueError) + << "Cannot convert to BufferLoad: " << ffi::GetRef(br); } } prim_expr_args.push_back(BufferLoad(br->buffer, indices)); @@ -632,17 +632,17 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Shuffle Shuffle::Shuffle(ffi::Array vectors, ffi::Array indices, Span span) { - ICHECK_NE(vectors.size(), 0U); - ICHECK_NE(indices.size(), 0U); + TVM_FFI_ICHECK_NE(vectors.size(), 0U); + TVM_FFI_ICHECK_NE(indices.size(), 0U); DataType base_type = vectors[0].dtype().element_of(); int total_lanes = 0; for (PrimExpr val : vectors) { - ICHECK(val.dtype().element_of() == base_type); + TVM_FFI_ICHECK(val.dtype().element_of() == base_type); total_lanes += val.dtype().lanes(); } - ICHECK_LE(indices.size(), static_cast(total_lanes)); + TVM_FFI_ICHECK_LE(indices.size(), static_cast(total_lanes)); ObjectPtr node = ffi::make_object(); node->dtype = base_type.with_lanes(static_cast(indices.size())); @@ -653,7 +653,7 @@ Shuffle::Shuffle(ffi::Array vectors, ffi::Array indices, Spa } PrimExpr Shuffle::Concat(ffi::Array vectors, Span span) { - ICHECK_NE(vectors.size(), 0); + TVM_FFI_ICHECK_NE(vectors.size(), 0); if (vectors.size() == 1) { return vectors[0]; } @@ -683,12 +683,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { CommReducer::CommReducer(ffi::Array lhs, ffi::Array rhs, ffi::Array result, ffi::Array identity_element, Span span) { size_t n_group = result.size(); - CHECK_EQ(lhs.size(), n_group) << "ValueError: The number of vars in `lhs` must equal to the " - "number of elements in `results`"; - CHECK_EQ(rhs.size(), n_group) << "ValueError: The number of vars in `rhs` must equal to the " - "number of elements in `results`"; - CHECK_EQ(identity_element.size(), n_group) - << "ValueError: The number of identities must equal to the number of elements in `results`"; + TVM_FFI_CHECK_EQ(lhs.size(), n_group, ValueError) + << "The number of vars in `lhs` must equal to the " + "number of elements in `results`"; + TVM_FFI_CHECK_EQ(rhs.size(), n_group, ValueError) + << "The number of vars in `rhs` must equal to the " + "number of elements in `results`"; + TVM_FFI_CHECK_EQ(identity_element.size(), n_group, ValueError) + << "The number of identities must equal to the number of elements in `results`"; // Change the dtype of input vars to adapt to the dtype of identities ffi::ArrayObj* p_lhs = lhs.CopyOnWrite(); @@ -722,9 +724,9 @@ CommReducer::CommReducer(ffi::Array lhs, ffi::Array rhs, ffi::Array CommReducerNode::operator()(ffi::Array a, ffi::Array b) const { - ICHECK_EQ(a.size(), b.size()); - ICHECK_EQ(lhs.size(), a.size()); - ICHECK_EQ(rhs.size(), b.size()); + TVM_FFI_ICHECK_EQ(a.size(), b.size()); + TVM_FFI_ICHECK_EQ(lhs.size(), a.size()); + TVM_FFI_ICHECK_EQ(rhs.size(), b.size()); ffi::Map value_map; for (size_t i = 0; i < a.size(); ++i) { value_map.Set(lhs[i], a[i]); @@ -747,22 +749,23 @@ TVM_FFI_STATIC_INIT_BLOCK() { Reduce::Reduce(CommReducer combiner, ffi::Array source, ffi::Array axis, PrimExpr condition, int value_index, ffi::Array init, Span span) { for (size_t i = 0; i < axis.size(); ++i) { - ICHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis"; + TVM_FFI_ICHECK_EQ(axis[i]->iter_type, kCommReduce) + << "Can only take axis created by reduce_axis"; } if (!condition.defined()) { condition = const_true(); } auto n = ffi::make_object(); - ICHECK(source.defined()); + TVM_FFI_ICHECK(source.defined()); for (size_t i = 0; i < axis.size(); ++i) { - ICHECK(axis[i].defined()); + TVM_FFI_ICHECK(axis[i].defined()); } if (!init.empty()) { - ICHECK_EQ(init.size(), source.size()) << "Number of inits should match number of exprs"; + TVM_FFI_ICHECK_EQ(init.size(), source.size()) << "Number of inits should match number of exprs"; for (size_t i = 0; i < init.size(); i++) { - ICHECK(init[i].defined()) << "Init value must be defined"; - ICHECK(init[i]->IsInstance() || init[i]->IsInstance() || - init[i]->IsInstance()) + TVM_FFI_ICHECK(init[i].defined()) << "Init value must be defined"; + TVM_FFI_ICHECK(init[i]->IsInstance() || init[i]->IsInstance() || + init[i]->IsInstance()) << "init can only be a IntImm, FloatImm or ProducerLoad, " << "but received " << init[i] << " of type " << init[i]->GetTypeKey(); } @@ -790,7 +793,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { // BufferLoad void BufferLoadNode::LegalizeDType() { for (int i = 0; i < static_cast(indices.size()) - 1; i++) { - ICHECK(indices[i].dtype().is_scalar()) + TVM_FFI_ICHECK(indices[i].dtype().is_scalar()) << "Only the last index of a buffer access may be a vector type."; } @@ -801,7 +804,7 @@ void BufferLoadNode::LegalizeDType() { bool is_buffer_dtype_scalable = buffer->dtype.is_scalable_vector(); bool is_index_scalable = index_dtype.is_scalable_vector(); - ICHECK(!(is_index_scalable && is_buffer_dtype_scalable)) + TVM_FFI_ICHECK(!(is_index_scalable && is_buffer_dtype_scalable)) << "Index dtype and buffer dtype can't both be scalable."; if (is_index_scalable) { @@ -818,7 +821,7 @@ void BufferLoadNode::LegalizeDType() { BufferLoad::BufferLoad(Buffer buffer, ffi::Array indices, ffi::Optional predicate, Span span) { - ICHECK_EQ(buffer->shape.size(), indices.size()) + TVM_FFI_ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() << "-dimensional indices provided."; @@ -828,19 +831,19 @@ BufferLoad::BufferLoad(Buffer buffer, ffi::Array indices, bool is_index_scalable = indices.empty() ? false : indices.back().dtype().is_scalable_vector(); bool is_predicate_scalable = predicate_dtype.is_scalable_vector(); - ICHECK_EQ(is_index_scalable, is_predicate_scalable) + TVM_FFI_ICHECK_EQ(is_index_scalable, is_predicate_scalable) << "Predicate mask dtype and load indices must both be scalable."; int buffer_lanes = buffer->dtype.get_lanes_or_vscale_factor(); int index_lanes = indices.empty() ? 1 : indices.back().dtype().get_lanes_or_vscale_factor(); int predicate_lanes = predicate_dtype.get_lanes_or_vscale_factor(); - ICHECK_EQ(index_lanes * buffer_lanes, predicate_lanes) + TVM_FFI_ICHECK_EQ(index_lanes * buffer_lanes, predicate_lanes) << "Got a predicate mask with " << predicate_lanes << " lanes, but trying to load a vector with " << index_lanes << " lanes. The number of lanes must match."; DataType predicate_element_dtype = predicate_dtype.element_of(); - ICHECK(predicate_element_dtype.is_predicate_dtype()) + TVM_FFI_ICHECK(predicate_element_dtype.is_predicate_dtype()) << "Predicate mask elements must be boolean values, but got " << predicate_element_dtype << "."; } diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index e5ced89425a2..1516bd8aa00e 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -118,16 +118,18 @@ class TensorIntrinManager { TensorIntrin::TensorIntrin(PrimFunc desc, PrimFunc impl) { // Check the number of func var is equal - CHECK_EQ(desc->params.size(), impl->params.size()) - << "ValueError: The number of parameters of the description and the implementation of the " + TVM_FFI_CHECK_EQ(desc->params.size(), impl->params.size(), ValueError) + << "The number of parameters of the description and the implementation of the " "tensor intrinsic doesn't match."; for (size_t i = 0; i < desc->params.size(); i++) { - CHECK(desc->params[i]->dtype.is_handle()) << "ValueError: Parameters of the description of the " - "tensor intrinsic should be handle only."; - CHECK(impl->params[i]->dtype.is_handle()) << "ValueError: Parameters of the implementation of " - "the tensor intrinsic should be handle only."; + TVM_FFI_CHECK(desc->params[i]->dtype.is_handle(), ValueError) + << "Parameters of the description of the " + "tensor intrinsic should be handle only."; + TVM_FFI_CHECK(impl->params[i]->dtype.is_handle(), ValueError) + << "Parameters of the implementation of " + "the tensor intrinsic should be handle only."; } - ICHECK_EQ(desc->buffer_map.size(), impl->buffer_map.size()); + TVM_FFI_ICHECK_EQ(desc->buffer_map.size(), impl->buffer_map.size()); ObjectPtr n = ffi::make_object(); n->desc = std::move(desc); @@ -138,8 +140,8 @@ TensorIntrin::TensorIntrin(PrimFunc desc, PrimFunc impl) { void TensorIntrin::Register(ffi::String name, TensorIntrin intrin, bool override) { TensorIntrinManager* manager = TensorIntrinManager::Global(); if (!override) { - CHECK_EQ(manager->reg.count(name), 0) - << "ValueError: TensorIntrin '" << name << "' has already been registered"; + TVM_FFI_CHECK_EQ(manager->reg.count(name), 0, ValueError) + << "TensorIntrin '" << name << "' has already been registered"; } manager->reg.Set(name, intrin); } @@ -151,7 +153,7 @@ ffi::Optional TensorIntrin::Get(ffi::String name, bool allow_missi if (allow_missing) { return std::nullopt; } else { - LOG(FATAL) << "ValueError: TensorIntrin '" << name << "' is not registered"; + TVM_FFI_THROW(ValueError) << "TensorIntrin '" << name << "' is not registered"; } } return (*it).second; diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index cdd1d8ad56d8..1c48d5fd1cc3 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -61,7 +61,7 @@ std::pair IndexMapInverseImpl(const IndexMap& self, const ffi::Array& initial_ranges, arith::IterMapLevel check_level, arith::Analyzer* analyzer) { - ICHECK(analyzer != nullptr); + TVM_FFI_ICHECK(analyzer != nullptr); if (self->inverse_index_map.defined()) { // return the pre-defined inverse index map if exists. In this // case, the user-defined inverse is assumed to be correct and @@ -87,7 +87,7 @@ std::pair IndexMapInverseImpl(const IndexMap& self, // Dummy ranges for the extent of each input. ffi::Map input_iters; - ICHECK_EQ(self->initial_indices.size(), initial_ranges.size()); + TVM_FFI_ICHECK_EQ(self->initial_indices.size(), initial_ranges.size()); for (size_t i = 0; i < initial_ranges.size(); i++) { input_iters.Set(self->initial_indices[i], initial_ranges[i]); } @@ -97,8 +97,9 @@ std::pair IndexMapInverseImpl(const IndexMap& self, auto padded_iter_map = DetectIterMap(self->final_indices, input_iters, /*predicate=*/1, /*check_level=*/check_level, analyzer, /*simplify_trivial_iterators=*/false); - CHECK(padded_iter_map->errors.empty()) << "Could not parse mapping as sum of iterators. " - << "Error: " << padded_iter_map->errors[0]; + TVM_FFI_ICHECK(padded_iter_map->errors.empty()) + << "Could not parse mapping as sum of iterators. " + << "Error: " << padded_iter_map->errors[0]; // Determine expressions for the input variables, in terms of the // output variables. @@ -124,7 +125,7 @@ std::pair IndexMapInverseImpl(const IndexMap& self, auto output_ranges = self->MapRanges(initial_ranges, analyzer); { - ICHECK_EQ(output_ranges.size(), output_vars.size()); + TVM_FFI_ICHECK_EQ(output_ranges.size(), output_vars.size()); arith::Analyzer analyzer; for (size_t i = 0; i < output_vars.size(); ++i) { @@ -140,15 +141,15 @@ std::pair IndexMapInverseImpl(const IndexMap& self, std::pair IndexMap::NonSurjectiveInverse(ffi::Array initial_ranges, arith::Analyzer* analyzer) const { - ICHECK(analyzer != nullptr); + TVM_FFI_ICHECK(analyzer != nullptr); return IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::NoCheck, analyzer); } IndexMap IndexMap::Inverse(ffi::Array initial_ranges, arith::Analyzer* analyzer) const { - ICHECK(analyzer != nullptr); + TVM_FFI_ICHECK(analyzer != nullptr); auto [inverse, padding_predicate] = IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::Bijective, analyzer); - CHECK(analyzer->CanProve(!padding_predicate)) + TVM_FFI_ICHECK(analyzer->CanProve(!padding_predicate)) << "Bijective inverse should not contain padding, but inverse of " << *this << " over range " << initial_ranges << " resulted in a padding predicate of " << padding_predicate; return inverse; @@ -156,8 +157,8 @@ IndexMap IndexMap::Inverse(ffi::Array initial_ranges, arith::Analyzer* an ffi::Array IndexMapNode::MapIndices(const ffi::Array& indices, arith::Analyzer* analyzer) const { - ICHECK(analyzer != nullptr); - ICHECK_EQ(indices.size(), initial_indices.size()); + TVM_FFI_ICHECK(analyzer != nullptr); + TVM_FFI_ICHECK_EQ(indices.size(), initial_indices.size()); ffi::Map vmap; @@ -175,8 +176,8 @@ ffi::Array IndexMapNode::MapIndices(const ffi::Array& indice ffi::Array IndexMapNode::MapRanges(const ffi::Array& ranges, arith::Analyzer* analyzer) const { - ICHECK(analyzer != nullptr); - ICHECK_EQ(ranges.size(), initial_indices.size()); + TVM_FFI_ICHECK(analyzer != nullptr); + TVM_FFI_ICHECK_EQ(ranges.size(), initial_indices.size()); ffi::Map input_iters; for (size_t i = 0; i < initial_indices.size(); i++) { @@ -239,8 +240,8 @@ ffi::Array IndexMapNode::MapRanges(const ffi::Array& ranges, ffi::Array IndexMapNode::MapShape(const ffi::Array& shape, arith::Analyzer* analyzer) const { - ICHECK(analyzer != nullptr); - ICHECK_EQ(shape.size(), initial_indices.size()); + TVM_FFI_ICHECK(analyzer != nullptr); + TVM_FFI_ICHECK_EQ(shape.size(), initial_indices.size()); ffi::Array ranges; for (auto& dim : shape) { @@ -250,7 +251,7 @@ ffi::Array IndexMapNode::MapShape(const ffi::Array& shape, ffi::Array output; for (auto& range : mapped) { - ICHECK(is_zero(range->min)); + TVM_FFI_ICHECK(is_zero(range->min)); output.push_back(range->extent); } @@ -260,7 +261,7 @@ ffi::Array IndexMapNode::MapShape(const ffi::Array& shape, runtime::Tensor IndexMapNode::MapTensor(runtime::Tensor arr_src) const { arith::Analyzer analyzer; auto shape = arr_src.Shape(); - ICHECK(shape.size() == initial_indices.size()) + TVM_FFI_ICHECK(shape.size() == initial_indices.size()) << "The rank of the input array should be " << initial_indices.size() << " but got " << shape.size(); size_t size_1d = 1; @@ -333,7 +334,7 @@ IndexMap IndexMap::RenameVariables( Var var = Downcast(obj); if (ffi::Optional opt_name = f_name_map(var); opt_name.has_value()) { ffi::String name = opt_name.value(); - ICHECK(!name_supply->ContainsName(name, /*add_prefix=*/false)); + TVM_FFI_ICHECK(!name_supply->ContainsName(name, /*add_prefix=*/false)); name_supply->ReserveName(name, /*add_prefix=*/false); var_remap.Set(var, Var(name, var->dtype)); } diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc index 4c2ccab58e10..1dc2e7bdff2c 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tir/ir/script/script_complete.cc @@ -43,7 +43,7 @@ class ScriptCompleter : public StmtMutator { ffi::Map* buffer_var_map_; Stmt VisitStmt_(const SBlockRealizeNode* op) final { for (const PrimExpr& value : op->iter_values) { - CHECK(value.dtype().is_int()) + TVM_FFI_ICHECK(value.dtype().is_int()) << "BlockRealize iter_value expected a IntImm, but got " << value.dtype(); } return StmtMutator::VisitStmt_(op); @@ -85,8 +85,8 @@ class ScriptCompleter : public StmtMutator { const ffi::Array& reads = access_region[0]; const ffi::Array& writes = access_region[1]; const ffi::Array& opaque = access_region[2]; - CHECK(opaque.empty()) - << "ValueError: Can not auto detect buffer access region from tir.Load, tir.Store or " + TVM_FFI_CHECK(opaque.empty(), ValueError) + << "Can not auto detect buffer access region from tir.Load, tir.Store or " "direct access by buffer data. Please annotation the access region manually"; auto n = CopyOnWrite(block.operator->()); if (!is_root_block) { diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 1de16efd2ac6..6c4f4666ef90 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -121,7 +121,7 @@ class PrimFuncSpecializer : public StmtExprMutator { // Step.1. Recursively visit block body Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); ffi::Array reads = op->reads.Map([this](const auto& region) { return MutateBufferRegion(region); }); @@ -180,7 +180,7 @@ class PrimFuncSpecializer : public StmtExprMutator { Stmt VisitStmt_(const BufferStoreNode* op) final { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); auto new_buf = GetNewBuffer(op->buffer); if (new_buf.same_as(op->buffer)) { @@ -195,7 +195,7 @@ class PrimFuncSpecializer : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); auto new_buf = GetNewBuffer(op->buffer); if (new_buf.same_as(op->buffer)) { @@ -273,7 +273,7 @@ class PrimFuncSpecializer : public StmtExprMutator { } Buffer MutateAllocBuffer(const Buffer& alloc_buf) { - ICHECK(!buffer_map_.count(alloc_buf)) + TVM_FFI_ICHECK(!buffer_map_.count(alloc_buf)) << "Multiple points of definition found for buffer " << alloc_buf; Buffer buf = MutateBuffer(alloc_buf); @@ -287,7 +287,7 @@ class PrimFuncSpecializer : public StmtExprMutator { } auto mutated = MutateBuffer(old_buffer); - ICHECK(mutated.same_as(old_buffer)) + TVM_FFI_ICHECK(mutated.same_as(old_buffer)) << "Buffer " << old_buffer << " (shape = " << old_buffer->shape << ")" << " was used without a declaration, " << "and would be specialized into " << mutated << " (shape = " << mutated->shape << "). " @@ -344,22 +344,22 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer tir::ExprDeepEqual equal; auto it = func->buffer_map.find(param); - CHECK(it != func->buffer_map.end()) - << "ValueError: specialize expects param to be in PrimFunc's buffer_map"; + TVM_FFI_CHECK(it != func->buffer_map.end(), ValueError) + << "specialize expects param to be in PrimFunc's buffer_map"; const Buffer& buf_to_specialize = (*it).second; // build var mapping using specific_buf's parameters auto build_var_mapping = [&](const PrimExpr& new_expr, const PrimExpr& old_expr) { if (!equal(new_expr, old_expr)) { - CHECK(old_expr->IsInstance()) - << "TypeError: The signature of target buffer exprected an independent Var, but got " - << old_expr << "."; + TVM_FFI_CHECK(old_expr->IsInstance(), TypeError) + << "The signature of target buffer exprected an independent Var, but got " << old_expr + << "."; const Var& var = Downcast(old_expr); auto it = var_map->find(var); if (it != var_map->end()) { - CHECK(equal(it->second, new_expr)) - << "ValueError: The assigned value of var " << var << " mismatched. " << it->second - << " vs. " << new_expr << "."; + TVM_FFI_CHECK(equal(it->second, new_expr), ValueError) + << "The assigned value of var " << var << " mismatched. " << it->second << " vs. " + << new_expr << "."; } else { (*var_map)[var] = new_expr; } @@ -367,13 +367,13 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer }; // Check buffer dimensions - CHECK(specific_buf->shape.size() == buf_to_specialize->shape.size()) - << "ValueError: The buffer dimensions mismatched" << buf_to_specialize->shape.size() - << " vs. " << specific_buf->shape.size() << "."; + TVM_FFI_CHECK(specific_buf->shape.size() == buf_to_specialize->shape.size(), ValueError) + << "The buffer dimensions mismatched" << buf_to_specialize->shape.size() << " vs. " + << specific_buf->shape.size() << "."; - CHECK(specific_buf->strides.size() == buf_to_specialize->strides.size()) - << "ValueError: The buffer strides dimensions mismatched" << buf_to_specialize->strides.size() - << " vs. " << specific_buf->strides.size() << "."; + TVM_FFI_CHECK(specific_buf->strides.size() == buf_to_specialize->strides.size(), ValueError) + << "The buffer strides dimensions mismatched" << buf_to_specialize->strides.size() << " vs. " + << specific_buf->strides.size() << "."; // Updating var mapping using specific_expr build_var_mapping(specific_buf->data, buf_to_specialize->data); @@ -387,13 +387,13 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer // Check data_alignment and offset_factor. // These two signatures are int, so we do not need map them. - CHECK_EQ(specific_buf->data_alignment, buf_to_specialize->data_alignment) - << "ValueError: The buffer data_alignment mismatched" << buf_to_specialize->data_alignment - << " vs. " << specific_buf->data_alignment << "."; + TVM_FFI_CHECK_EQ(specific_buf->data_alignment, buf_to_specialize->data_alignment, ValueError) + << "The buffer data_alignment mismatched" << buf_to_specialize->data_alignment << " vs. " + << specific_buf->data_alignment << "."; - CHECK_EQ(specific_buf->offset_factor, buf_to_specialize->offset_factor) - << "ValueError: The buffer offset_factor mismatched" << buf_to_specialize->offset_factor - << " vs. " << specific_buf->offset_factor << "."; + TVM_FFI_CHECK_EQ(specific_buf->offset_factor, buf_to_specialize->offset_factor, ValueError) + << "The buffer offset_factor mismatched" << buf_to_specialize->offset_factor << " vs. " + << specific_buf->offset_factor << "."; } /*! @@ -406,10 +406,11 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimExpr& specific_expr, VarMap* var_map) { // check param is in PrimFunc's parameters - CHECK(IsParam(func, param)) << "ValueError: Specialize expects param to be in PrimFunc's params"; + TVM_FFI_CHECK(IsParam(func, param), ValueError) + << "Specialize expects param to be in PrimFunc's params"; // specialize a param not in buffer_map - CHECK_EQ(func->buffer_map.count(param), 0) - << "ValueError: Specialize expects param to not be in PrimFunc's buffer_map"; + TVM_FFI_CHECK_EQ(func->buffer_map.count(param), 0, ValueError) + << "Specialize expects param to not be in PrimFunc's buffer_map"; // build var mapping using specific_expr (*var_map)[param] = specific_expr; } @@ -426,7 +427,7 @@ PrimFunc Specialize(PrimFunc func, const ffi::Map()) { UpdateSpecializeVarMap(func, param, opt_expr.value(), &var_map); } else { - LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or PrimExpr"; + TVM_FFI_THROW(TypeError) << "specialize expected instance to be Buffer or PrimExpr"; } } return PrimFuncSpecializer::Specialize(func, std::move(var_map)); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 6f6b2f7e149a..4f0dbaf12111 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -53,15 +53,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { // LetStmt LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { - ICHECK(value.defined()); - ICHECK(body.defined()); + TVM_FFI_ICHECK(value.defined()); + TVM_FFI_ICHECK(body.defined()); auto vdtype = value.dtype(); // It is still valid to bind a pointer type // var to a value that is of type handle. if (var->type_annotation.as()) { - ICHECK(vdtype.is_handle()); + TVM_FFI_ICHECK(vdtype.is_handle()); } else { - ICHECK_EQ(value.dtype(), var.dtype()); + TVM_FFI_ICHECK_EQ(value.dtype(), var.dtype()); } ObjectPtr node = ffi::make_object(); @@ -105,12 +105,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { // AssertStmt AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span) { - ICHECK(condition.defined()); - CHECK(condition.dtype().is_bool()) + TVM_FFI_ICHECK(condition.defined()); + TVM_FFI_ICHECK(condition.dtype().is_bool()) << "AssertStmt should have boolean condition, " << "but received " << condition << " with dtype " << condition.dtype(); - ICHECK(message.dtype() == DataType::Int(32) || message.as()) - << "TypeError: AssertStmt message must be an int or string:" << message << "\n"; + TVM_FFI_CHECK(message.dtype() == DataType::Int(32) || message.as(), TypeError) + << "AssertStmt message must be an int or string:" << message << "\n"; ObjectPtr node = ffi::make_object(); node->condition = std::move(condition); @@ -132,14 +132,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, ffi::Optional thread_binding, ffi::Map annotations, ffi::Optional step, Span span) { - ICHECK(loop_var.defined()); - ICHECK(min.defined()); - ICHECK(extent.defined()); - ICHECK(body.defined()); + TVM_FFI_ICHECK(loop_var.defined()); + TVM_FFI_ICHECK(min.defined()); + TVM_FFI_ICHECK(extent.defined()); + TVM_FFI_ICHECK(body.defined()); auto require_scalar_int_dtype = [&](PrimExpr expr, const char* field_name) { auto dtype = expr.dtype(); - CHECK(dtype.is_scalar() && (dtype.is_int() || dtype.is_uint())) + TVM_FFI_ICHECK(dtype.is_scalar() && (dtype.is_int() || dtype.is_uint())) << "TIR For nodes require a scalar integer as the " << field_name << ", but received " << expr << " with dtype " << dtype; }; @@ -150,7 +150,7 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, // When extent, min or step is an IntImm but has narrower dtype than loop_var // we directly promote them without raising errors. auto try_promote_imm_dtype = [&](const PrimExpr& e) { - ICHECK(e.dtype().bits() <= loop_var.dtype().bits()) + TVM_FFI_ICHECK(e.dtype().bits() <= loop_var.dtype().bits()) << " Loop variable's dtype (" << loop_var.dtype() << ") is narrower than that of `min` or `extent` (" << e.dtype() << ")"; const IntImmNode* a = e.as(); @@ -164,13 +164,15 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, min = try_promote_imm_dtype(min); extent = try_promote_imm_dtype(extent); - ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype(); - ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype(); + TVM_FFI_ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype(); + TVM_FFI_ICHECK(loop_var.dtype() == extent.dtype()) + << loop_var.dtype() << " vs " << extent.dtype(); if (step.has_value()) { require_scalar_int_dtype(*step, "step"); step = try_promote_imm_dtype(*step); - ICHECK(loop_var.dtype() == (*step).dtype()) << loop_var.dtype() << " vs " << (*step).dtype(); + TVM_FFI_ICHECK(loop_var.dtype() == (*step).dtype()) + << loop_var.dtype() << " vs " << (*step).dtype(); } ObjectPtr node = ffi::make_object(); @@ -222,9 +224,9 @@ std::ostream& operator<<(std::ostream& out, ForKind type) { // NOLINT(*) // While While::While(PrimExpr condition, Stmt body, Span span) { - ICHECK(condition.defined()); - ICHECK(condition.dtype().is_scalar()); - ICHECK(body.defined()); + TVM_FFI_ICHECK(condition.defined()); + TVM_FFI_ICHECK(condition.dtype().is_scalar()); + TVM_FFI_ICHECK(body.defined()); ObjectPtr node = ffi::make_object(); node->condition = std::move(condition); @@ -243,20 +245,20 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Allocate Allocate::Allocate(Var buffer_var, DataType dtype, ffi::Array extents, PrimExpr condition, Stmt body, ffi::Map annotations, Span span) { - CHECK(IsPointerType(buffer_var->type_annotation, dtype) || - (dtype.is_bool() && IsPointerType(buffer_var->type_annotation, DataType::Int(8)))) + TVM_FFI_ICHECK(IsPointerType(buffer_var->type_annotation, dtype) || + (dtype.is_bool() && IsPointerType(buffer_var->type_annotation, DataType::Int(8)))) << "The allocated data type (" << dtype << ") does not match the type annotation of the buffer " << buffer_var << " (" << buffer_var->type_annotation << "). The data type should be an element of the pointer type."; for (size_t i = 0; i < extents.size(); ++i) { - ICHECK(extents[i].defined()); - ICHECK(extents[i].dtype().is_scalar()); + TVM_FFI_ICHECK(extents[i].defined()); + TVM_FFI_ICHECK(extents[i].dtype().is_scalar()); } - ICHECK(body.defined()); - ICHECK(condition.defined()); - ICHECK(condition.dtype().is_bool()); + TVM_FFI_ICHECK(body.defined()); + TVM_FFI_ICHECK(condition.defined()); + TVM_FFI_ICHECK(condition.dtype().is_bool()); ObjectPtr node = ffi::make_object(); node->buffer_var = std::move(buffer_var); @@ -324,12 +326,12 @@ SeqStmt::SeqStmt(ffi::Array seq, Span span) { } } - ICHECK_NE(seq.size(), 0) << "An empty SeqStmt is prohibited. " - << "To write a no-op, use Evaluate(0), " - << "or the result of SeqStmt::Flatten()"; - ICHECK_NE(seq.size(), 1) << "A SeqStmt of length 1 is prohibited. " - << "Use the node " << seq[0] << "directly, " - << "or for dynamic usage, normalize using SeqStmt::Flatten()"; + TVM_FFI_ICHECK_NE(seq.size(), 0) << "An empty SeqStmt is prohibited. " + << "To write a no-op, use Evaluate(0), " + << "or the result of SeqStmt::Flatten()"; + TVM_FFI_ICHECK_NE(seq.size(), 1) << "A SeqStmt of length 1 is prohibited. " + << "Use the node " << seq[0] << "directly, " + << "or for dynamic usage, normalize using SeqStmt::Flatten()"; auto node = ffi::make_object(); node->seq = std::move(seq); @@ -346,8 +348,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { // IfThenElse IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, ffi::Optional else_case, Span span) { - ICHECK(condition.defined()); - ICHECK(then_case.defined()); + TVM_FFI_ICHECK(condition.defined()); + TVM_FFI_ICHECK(then_case.defined()); // else_case may be null. ObjectPtr node = ffi::make_object(); node->condition = std::move(condition); @@ -367,7 +369,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Evaluate Evaluate::Evaluate(PrimExpr value, Span span) { - ICHECK(value.defined()); + TVM_FFI_ICHECK(value.defined()); ObjectPtr node = ffi::make_object(); node->value = std::move(value); @@ -384,13 +386,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { // BufferStore BufferStore::BufferStore(Buffer buffer, PrimExpr value, ffi::Array indices, ffi::Optional predicate, Span span) { - ICHECK_EQ(buffer->shape.size(), indices.size()) + TVM_FFI_ICHECK_EQ(buffer->shape.size(), indices.size()) << "Buffer " << buffer->name << " is " << buffer->shape.size() << "-dimensional, cannot be indexed with the " << indices.size() << "-dimensional indices provided."; for (int i = 0; i < static_cast(indices.size()) - 1; i++) { - ICHECK(indices[i].dtype().is_scalar()) + TVM_FFI_ICHECK(indices[i].dtype().is_scalar()) << "Only the last index of a buffer access may be a vector type."; } @@ -398,24 +400,24 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, ffi::Array ind bool is_buffer_dtype_scalable = buffer->dtype.is_scalable_vector(); bool is_value_dtype_scalable = value.dtype().is_scalable_vector(); - ICHECK(!(is_index_scalable && is_buffer_dtype_scalable)) + TVM_FFI_ICHECK(!(is_index_scalable && is_buffer_dtype_scalable)) << "Index dtype and buffer dtype can't both be scalable."; if (predicate.defined()) { bool is_predicate_dtype_scalable = predicate.value().dtype().is_scalable_vector(); - ICHECK_EQ(is_value_dtype_scalable, is_predicate_dtype_scalable) + TVM_FFI_ICHECK_EQ(is_value_dtype_scalable, is_predicate_dtype_scalable) << "Predicate mask dtype and value dtype must both be scalable."; } if (is_index_scalable || is_buffer_dtype_scalable) { - ICHECK(is_value_dtype_scalable) << "Can't store non-scalable data into scalable buffer"; + TVM_FFI_ICHECK(is_value_dtype_scalable) << "Can't store non-scalable data into scalable buffer"; } int index_lanes = indices.empty() ? 1 : indices.back().dtype().get_lanes_or_vscale_factor(); int buffer_lanes = buffer->dtype.get_lanes_or_vscale_factor(); int value_dtype_lanes = value.dtype().get_lanes_or_vscale_factor(); - ICHECK_EQ(index_lanes * buffer_lanes, value_dtype_lanes) + TVM_FFI_ICHECK_EQ(index_lanes * buffer_lanes, value_dtype_lanes) << "Cannot store value with " << value_dtype_lanes << ", expected value with " << index_lanes * buffer_lanes << " (" << index_lanes << " index lanes * " << buffer_lanes << " buffer element lanes)"; @@ -423,13 +425,13 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, ffi::Array ind if (predicate.defined()) { DataType predicate_dtype = predicate.value().dtype(); int predicate_dtype_lanes = predicate_dtype.get_lanes_or_vscale_factor(); - ICHECK_EQ(value_dtype_lanes, predicate_dtype_lanes) + TVM_FFI_ICHECK_EQ(value_dtype_lanes, predicate_dtype_lanes) << "Got a predicate mask with " << predicate_dtype_lanes << " lanes, but trying to store a value with " << value_dtype_lanes << " lanes. The number of lanes must match."; DataType predicate_element_dtype = predicate_dtype.element_of(); - ICHECK(predicate_element_dtype.is_predicate_dtype()) + TVM_FFI_ICHECK(predicate_element_dtype.is_predicate_dtype()) << "Predicate mask elements must be boolean values, but got " << predicate_element_dtype << "."; } @@ -441,11 +443,11 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, ffi::Array ind buffer_dtype = buffer->dtype.with_lanes(buffer_lanes * index_lanes); } if (buffer_dtype != value.dtype()) { - LOG(FATAL) << "TypeError: dtype mismatch on BufferStore: " // - << "buffer's dtype is `" << buffer->dtype // - << "`, the lanes of indexing are: `" << index_lanes // - << "`, the scalability is: `" << buffer_dtype.is_scalable_vector() - << "`, but RHS's dtype is `" << value.dtype() << "`"; + TVM_FFI_THROW(TypeError) << "dtype mismatch on BufferStore: " // + << "buffer's dtype is `" << buffer->dtype // + << "`, the lanes of indexing are: `" << index_lanes // + << "`, the scalability is: `" << buffer_dtype.is_scalable_vector() + << "`, but RHS's dtype is `" << value.dtype() << "`"; } ObjectPtr node = ffi::make_object(); @@ -477,14 +479,15 @@ PrimExpr BufferRegionNode::ToPrimExpr() const { } else if (r->extent.as()) { indices.push_back(tir::Ramp(r->min, tvm::tir::make_const(r->min->dtype, 1), r->extent)); } else { - LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " << ffi::GetRef(this); + TVM_FFI_THROW(ValueError) << "Cannot convert to BufferLoad: " + << ffi::GetRef(this); } } return tir::BufferLoad(this->buffer, indices); } BufferRegion::BufferRegion(Buffer buffer, ffi::Array region) { - CHECK_EQ(buffer->shape.size(), region.size()) + TVM_FFI_ICHECK_EQ(buffer->shape.size(), region.size()) << "The dimension between " << buffer << " and region " << region << " mismatched, the buffer is " << buffer; ObjectPtr node = ffi::make_object(); @@ -526,39 +529,39 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { const Buffer& source_buffer = source->buffer; arith::Analyzer analyzer; // Check scope and dtype - CHECK_EQ(buffer.scope(), source_buffer.scope()) + TVM_FFI_ICHECK_EQ(buffer.scope(), source_buffer.scope()) << "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << " vs. " << source_buffer.scope(); - CHECK_EQ(buffer->dtype, source_buffer->dtype) + TVM_FFI_ICHECK_EQ(buffer->dtype, source_buffer->dtype) << "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << " vs. " << source_buffer->dtype; // Check data_alignment - CHECK(source_buffer->data_alignment % buffer->data_alignment == 0) + TVM_FFI_ICHECK(source_buffer->data_alignment % buffer->data_alignment == 0) << "Trying to match buffer to another one with lower alignment requirement " << " required alignment=" << buffer->data_alignment << ", provided alignment=" << source_buffer->data_alignment; // Check BufferType. AutoBroadcast is not allowed for now. - CHECK(buffer->buffer_type == BufferType::kDefault && - source_buffer->buffer_type == BufferType::kDefault) + TVM_FFI_ICHECK(buffer->buffer_type == BufferType::kDefault && + source_buffer->buffer_type == BufferType::kDefault) << "AutoBroadcast is not allowed in MatchBuffer"; // Validate shape - CHECK(source->region.size() >= buffer->shape.size()) + TVM_FFI_ICHECK(source->region.size() >= buffer->shape.size()) << "Dimension of source Region expected to be larger or equal than target buffer shape, but " "got " << source->region.size() << " vs. " << buffer->shape.size(); size_t offset = source->region.size() - buffer->shape.size(); for (size_t i = 0; i < offset; ++i) { - CHECK(analyzer.CanProve(source->region[i]->extent == 1)) + TVM_FFI_ICHECK(analyzer.CanProve(source->region[i]->extent == 1)) << "The higher dimension should be 1, but got " << source->region[i]->extent << "."; } for (size_t i = 0; i < buffer->shape.size(); ++i) { const Range& source_range = source->region[i + offset]; const PrimExpr& buffer_shape = buffer->shape[i]; if (!buffer_shape->IsInstance()) { - CHECK(analyzer.CanProve(source_range->extent == buffer_shape)) + TVM_FFI_ICHECK(analyzer.CanProve(source_range->extent == buffer_shape)) << "The dimension mismatched between source region and target buffer shape, got " << source_range->extent << " vs. " << buffer_shape << "."; } @@ -615,10 +618,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { // BlockRealize SBlockRealize::SBlockRealize(ffi::Array values, PrimExpr predicate, SBlock block, Span span) { - CHECK_EQ(block->iter_vars.size(), values.size()) - << "ValueError: BlockRealize needs to have the same number of iter_vars and binding values"; - CHECK(predicate.dtype().is_bool() || predicate.dtype() == DataType::UInt(1)) - << "TypeError: Expect Block.predicate to be a bool expression"; + TVM_FFI_CHECK_EQ(block->iter_vars.size(), values.size(), ValueError) + << "BlockRealize needs to have the same number of iter_vars and binding values"; + TVM_FFI_CHECK(predicate.dtype().is_bool() || predicate.dtype() == DataType::UInt(1), TypeError) + << "Expect Block.predicate to be a bool expression"; ObjectPtr node = ffi::make_object(); node->iter_values = std::move(values); node->predicate = std::move(predicate); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index ef91f128bc8e..a3e59914c13f 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -564,8 +564,9 @@ class IRSubstitute : public StmtExprMutator { // uses void variables for lambda parameters (since exact types are not known yet). if (!var.dtype().is_void()) { PrimExpr ret_ex = Downcast(ret.value()); - ICHECK(ret_ex.dtype() == var.dtype()) << "substituting " << var << ":" << var.dtype() - << " -> " << ret_ex << ":" << ret_ex.dtype(); + TVM_FFI_ICHECK(ret_ex.dtype() == var.dtype()) + << "substituting " << var << ":" << var.dtype() << " -> " << ret_ex << ":" + << ret_ex.dtype(); } return ret.value(); } @@ -607,7 +608,7 @@ class IRSubstitute : public StmtExprMutator { } PrimExpr new_buffer_var_expr = VisitExpr(buf->data); - CHECK(new_buffer_var_expr->IsInstance()) + TVM_FFI_ICHECK(new_buffer_var_expr->IsInstance()) << "Buffer " << buf << " uses backing allocation " << buf->data << ", which was substituted into the expression " << new_buffer_var_expr << ". " << "However, this expression is of type " << new_buffer_var_expr->GetTypeKey() @@ -702,8 +703,8 @@ void PreOrderVisit(const ObjectRef& stmt_or_expr, } else if (auto expr = stmt_or_expr.as()) { visitor(expr.value()); } else { - LOG(FATAL) << "InternalError: PreOrderVisit does not accept object with type: " - << stmt_or_expr->GetTypeKey(); + TVM_FFI_THROW(InternalError) << "PreOrderVisit does not accept object with type: " + << stmt_or_expr->GetTypeKey(); } } diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 3aa91484b922..2e9968790f1b 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -194,7 +194,7 @@ void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, AccessPath path) { // symbolic shapes used within `buffer_view that are not already // defined. ffi::Array arr = Downcast>(op->node); - ICHECK_EQ(arr.size(), 2U); + TVM_FFI_ICHECK_EQ(arr.size(), 2U); Buffer buffer_view = Downcast(arr[0]); Buffer orig_buffer = Downcast(arr[1]); Visit(orig_buffer, path->Attr("node")->ArrayItem(1)); diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 68b494d41144..a525539b8f70 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -108,7 +108,7 @@ PrimFuncPass::PrimFuncPass(std::function Module optimizations at the PrimFunc level. IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { - ICHECK(mod.defined()); + TVM_FFI_ICHECK(mod.defined()); std::vector deleted_list; IRModuleNode* mod_ptr = mod.CopyOnWrite(); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 51c0b64ed295..8c2b4bd85962 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -58,7 +58,8 @@ runtime::DataType GetRuntimeDataType(const Type& type) { } else if (IsVoidType(type)) { return DataType::Void(); } else { - LOG(FATAL) << "Type " << type << " does not have a corresponding runtime::DataType"; + TVM_FFI_THROW(InternalError) << "Type " << type + << " does not have a corresponding runtime::DataType"; } } @@ -75,10 +76,11 @@ Type GetType(const PrimExpr& expr) { if (auto* access = expr.as()) { if (access->op.same_as(builtin::tvm_access_ptr())) { - ICHECK(access->args.size()) << "Builtin tvm_access_ptr() may not have empty arguments"; + TVM_FFI_ICHECK(access->args.size()) + << "Builtin tvm_access_ptr() may not have empty arguments"; auto type_annotation = Downcast(access->args[0]); static auto builtin_op = Op::Get("tir.type_annotation"); - ICHECK(type_annotation->op.same_as(builtin_op)) + TVM_FFI_ICHECK(type_annotation->op.same_as(builtin_op)) << "Expected the first argument of builtin tvm_access_ptr() " << "to be a type annotation, but found " << type_annotation->op; return PointerType(PrimType(type_annotation->dtype)); @@ -87,11 +89,11 @@ Type GetType(const PrimExpr& expr) { if (auto* address_of = expr.as()) { if (address_of->op.same_as(builtin::address_of())) { - ICHECK_EQ(address_of->args.size(), 1) + TVM_FFI_ICHECK_EQ(address_of->args.size(), 1) << "Builtin address_of() expects a single argument, but received arguments " << address_of->args; auto* address = address_of->args[0].as(); - ICHECK(address) + TVM_FFI_ICHECK(address) << "Builtin address_of() expects the argument to be a BufferLoad, but received argument " << address_of->args[0]; @@ -141,8 +143,8 @@ void BroadcastToMatchLanes(PrimExpr& op_a, PrimExpr& op_b) { // NOLINT(*) // The public function with a quick checking path. void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) - CHECK(lhs.defined()) << "ValueError: `lhs` is null in the binary operator"; - CHECK(rhs.defined()) << "ValueError: `rhs` is null in the binary operator"; + TVM_FFI_CHECK(lhs.defined(), ValueError) << "`lhs` is null in the binary operator"; + TVM_FFI_CHECK(rhs.defined(), ValueError) << "`rhs` is null in the binary operator"; if (lhs.dtype() == rhs.dtype()) return; BroadcastToMatchLanes(lhs, rhs); @@ -151,7 +153,7 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) DataType ltype = lhs.dtype(); DataType rtype = rhs.dtype(); - ICHECK(ltype.is_scalable_vector() == rtype.is_scalable_vector()) + TVM_FFI_ICHECK(ltype.is_scalable_vector() == rtype.is_scalable_vector()) << "Can't match scalable and fixed length vectors"; bool lanes_match = false; @@ -162,7 +164,7 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) lanes_match = ltype.lanes() == rtype.lanes(); } - ICHECK(lanes_match) << "Cannot match type " << ltype << " vs " << rtype; + TVM_FFI_ICHECK(lanes_match) << "Cannot match type " << ltype << " vs " << rtype; if (lhs.dtype() == rhs.dtype()) return; ltype = lhs.dtype(); @@ -243,12 +245,12 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } } else { LOG(INFO) << lhs << " " << rhs; - LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype; + TVM_FFI_THROW(InternalError) << "Cannot match type " << ltype << " vs " << rtype; } } PrimExpr ret(PrimExpr value, Span span) { - CHECK(value.defined()); + TVM_FFI_ICHECK(value.defined()); return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span); } @@ -276,7 +278,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { // maximum and min limits PrimExpr max_value(const DataType& dtype, Span span) { using namespace tir; - ICHECK_EQ(dtype.lanes(), 1); + TVM_FFI_ICHECK_EQ(dtype.lanes(), 1); if (dtype.is_int()) { if (dtype.bits() == 64) { return IntImm(dtype, std::numeric_limits::max(), span); @@ -330,16 +332,17 @@ PrimExpr max_value(const DataType& dtype, Span span) { } else if (dtype.is_float4()) { return FloatImm(dtype, 6.0, span); } - LOG(FATAL) << "Cannot decide max_value for type" << dtype; + TVM_FFI_THROW(InternalError) << "Cannot decide max_value for type" << dtype; } PrimExpr min_value(const DataType& dtype, Span span) { using namespace tir; - ICHECK_EQ(dtype.lanes(), 1); + TVM_FFI_ICHECK_EQ(dtype.lanes(), 1); if (datatype::Registry::Global()->GetTypeRegistered(dtype.code())) { // TODO(tkonolige): need to convert all registered min functions to use the span. auto f = datatype::GetMinFunc(dtype.code()); - ICHECK(f) << "No minimum function registered for custom dtype " << (unsigned int)dtype.code(); + TVM_FFI_ICHECK(f) << "No minimum function registered for custom dtype " + << (unsigned int)dtype.code(); // TODO(@hypercubestart) Document this change (and others associated with the overflowing // floatimm min bug) return (*f)(dtype.bits()).cast(); @@ -391,13 +394,13 @@ PrimExpr min_value(const DataType& dtype, Span span) { } else if (dtype.is_float4()) { return FloatImm(dtype, -6.0, span); } - LOG(FATAL) << "Cannot decide min_value for type" << dtype; + TVM_FFI_THROW(InternalError) << "Cannot decide min_value for type" << dtype; } // infinity PrimExpr infinity(const DataType& dtype, Span span) { using namespace tir; - ICHECK_EQ(dtype.lanes(), 1); + TVM_FFI_ICHECK_EQ(dtype.lanes(), 1); if (dtype.is_float()) { if (dtype.bits() == 64) { return FloatImm(dtype, std::numeric_limits::infinity(), span); @@ -405,7 +408,7 @@ PrimExpr infinity(const DataType& dtype, Span span) { return FloatImm(dtype, std::numeric_limits::infinity(), span); } } - LOG(FATAL) << "Cannot decide infinity for type " << dtype; + TVM_FFI_THROW(InternalError) << "Cannot decide infinity for type " << dtype; } namespace tir { @@ -442,7 +445,7 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { } else if (const FloatImmNode* op = value.as()) { return make_const(t, op->value, op->span); } - ICHECK(!value.dtype().is_handle()) << "Can't cast a handle to other types."; + TVM_FFI_ICHECK(!value.dtype().is_handle()) << "Can't cast a handle to other types."; return tir::Cast(t, value, span); } else { DataType vtype = t.element_of(); @@ -465,7 +468,7 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { return tir::Broadcast(value, t.lanes(), span); } } else { /* value is a vector */ - ICHECK(value.dtype().is_scalable_vector() == t.is_scalable_vector()); + TVM_FFI_ICHECK(value.dtype().is_scalable_vector() == t.is_scalable_vector()); bool lanes_match = false; if (value.dtype().is_scalable_vector()) { @@ -473,7 +476,7 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { } else { lanes_match = value.dtype().lanes() == t.lanes(); } - ICHECK(lanes_match); + TVM_FFI_ICHECK(lanes_match); if (const auto* broadcast = value.as()) { return tir::Broadcast(cast(vtype, broadcast->value, span), broadcast->lanes, span); } else if (const auto* ramp = value.as()) { @@ -492,9 +495,9 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) { if (value.dtype() == t) return value; if (!t.is_scalable_vector() && !value.dtype().is_scalable_vector()) { - ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes() || - ((value.dtype().is_float4_e2m1fn() || t.is_float4_e2m1fn()) && - value.dtype().bytes() * value.dtype().lanes() == t.bytes() * t.lanes())) + TVM_FFI_ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes() || + ((value.dtype().is_float4_e2m1fn() || t.is_float4_e2m1fn()) && + value.dtype().bytes() * value.dtype().lanes() == t.bytes() * t.lanes())) << "Reinterpret requires size match " << t << " vs " << value.dtype(); } return tir::Call(t, tir::builtin::reinterpret(), {value}, span); @@ -544,8 +547,8 @@ PrimExpr div(PrimExpr a, PrimExpr b, Span span) { } PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span) { - ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; - ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; + TVM_FFI_ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; + TVM_FFI_ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; return div(a, b, span); } @@ -567,16 +570,16 @@ PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span) { return ceildiv(a, b, span PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span) { return floormod(a, b, span); } PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) { - ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; - ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; + TVM_FFI_ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; + TVM_FFI_ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b, span); if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::FloorDiv(a, b, span); } PrimExpr logaddexp(PrimExpr a, PrimExpr b, Span span) { - ICHECK(a.dtype().is_float()) << a; - ICHECK(b.dtype().is_float()) << b; + TVM_FFI_ICHECK(a.dtype().is_float()) << a; + TVM_FFI_ICHECK(b.dtype().is_float()) << b; BinaryOpMatchTypes(a, b, span); PrimExpr exp_sum = add(exp(a), exp(b)); PrimExpr log_exp_sum = log(exp_sum); @@ -584,16 +587,16 @@ PrimExpr logaddexp(PrimExpr a, PrimExpr b, Span span) { } PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span) { - ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; - ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; + TVM_FFI_ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; + TVM_FFI_ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b, span); if (auto ret = arith::TryConstFold(a + b - 1, b)) return ret.value(); return tir::FloorDiv(a + b - 1, b, span); } PrimExpr floormod(PrimExpr a, PrimExpr b, Span span) { - ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; - ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; + TVM_FFI_ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; + TVM_FFI_ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b, span); if (auto ret = arith::TryConstFold(a, b)) return ret.value(); return tir::FloorMod(a, b, span); @@ -627,7 +630,7 @@ PrimExpr max(PrimExpr a, PrimExpr b, Span span) { // if_then_else PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) { - ICHECK(cond.dtype() == DataType::Bool()) + TVM_FFI_ICHECK(cond.dtype() == DataType::Bool()) << "if_then_else only accept the condition to be boolean type."; BinaryOpMatchTypes(true_value, false_value, span); if (const IntImmNode* op = cond.as()) { @@ -694,36 +697,36 @@ PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span) { namespace { void type_check_boolean_args(const PrimExpr& arg, const char* op) { - ICHECK(arg.dtype().is_bool()) << "Expected boolean argument for " << op << ", but received " - << arg << " of type " << arg.dtype(); + TVM_FFI_ICHECK(arg.dtype().is_bool()) << "Expected boolean argument for " << op + << ", but received " << arg << " of type " << arg.dtype(); } void type_check_boolean_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) { - ICHECK(lhs.dtype().is_bool()) << "Expected boolean argument as LHS of " << op << ", but received " - << lhs << " of type " << lhs.dtype(); - ICHECK(rhs.dtype().is_bool()) << "Expected boolean argument as RHS of " << op << ", but received " - << rhs << " of type " << rhs.dtype(); + TVM_FFI_ICHECK(lhs.dtype().is_bool()) << "Expected boolean argument as LHS of " << op + << ", but received " << lhs << " of type " << lhs.dtype(); + TVM_FFI_ICHECK(rhs.dtype().is_bool()) << "Expected boolean argument as RHS of " << op + << ", but received " << rhs << " of type " << rhs.dtype(); } void type_check_int_or_bool_args(const PrimExpr& arg, const char* op) { - ICHECK(arg.dtype().is_int() || arg.dtype().is_uint() || arg.dtype().is_bool()) + TVM_FFI_ICHECK(arg.dtype().is_int() || arg.dtype().is_uint() || arg.dtype().is_bool()) << "Expected integer or boolean argument for " << op << ", but received " << arg << " of type " << arg.dtype(); } void type_check_integer_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) { - ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint()) + TVM_FFI_ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint()) << "Expected integer argument as LHS of " << op << ", but received " << lhs << " of type " << lhs.dtype(); - ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint()) + TVM_FFI_ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint()) << "Expected integer argument as RHS of " << op << ", but received " << rhs << " of type " << rhs.dtype(); } void type_check_int_or_bool_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) { - ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint() || lhs.dtype().is_bool()) + TVM_FFI_ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint() || lhs.dtype().is_bool()) << "Expected integer argument as LHS of " << op << ", but received " << lhs << " of type " << lhs.dtype(); - ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint() || rhs.dtype().is_bool()) + TVM_FFI_ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint() || rhs.dtype().is_bool()) << "Expected integer argument as RHS of " << op << ", but received " << rhs << " of type " << rhs.dtype(); } @@ -760,7 +763,7 @@ PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pb) - ICHECK(pb->value >= 0 && pb->value < rtype.bits()) + TVM_FFI_ICHECK(pb->value >= 0 && pb->value < rtype.bits()) << "Shift amount must be non-negative and less than " << rtype.bits() << " for type " << rtype; if (pa && pb) { @@ -782,7 +785,7 @@ PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pb) - ICHECK(pb->value >= 0 && pb->value < rtype.bits()) + TVM_FFI_ICHECK(pb->value >= 0 && pb->value < rtype.bits()) << "Shift amount must be non-negative and less than " << rtype.bits() << " for type " << rtype; if (pa && pb) return IntImm(rtype, (pa->value << pb->value), span); @@ -846,7 +849,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { // pow PrimExpr pow(PrimExpr x, PrimExpr y, Span span) { BinaryOpMatchTypes(x, y, span); - ICHECK(x.dtype().is_float()) << "power only applies to float"; + TVM_FFI_ICHECK(x.dtype().is_float()) << "power only applies to float"; // If we detect pow(x, 3), suggest using x * x * x if (y.dtype().is_int()) { @@ -899,8 +902,8 @@ PrimExpr abs(PrimExpr x, Span span) { } else if (x.dtype().is_uint()) { return x; } else { - LOG(FATAL) << "Data type " << x.dtype() - << " not supported for absolute op. Skipping absolute op..."; + TVM_FFI_THROW(InternalError) << "Data type " << x.dtype() + << " not supported for absolute op. Skipping absolute op..."; return x; } } @@ -925,7 +928,8 @@ PrimExpr isnan(PrimExpr x, Span span) { return tir::Call(t, op, {x}, span); } } else { - LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. Skipping isnan op..."; + TVM_FFI_THROW(InternalError) << "Data type " << x.dtype() + << " not supported for isnan op. Skipping isnan op..."; } } @@ -938,7 +942,8 @@ PrimExpr isinf(PrimExpr x, Span span) { PrimExpr infX = infinity(x.dtype(), span); return abs(x, span) == infX && !isnan(x, span); } else { - LOG(FATAL) << "Data type " << x.dtype() << " not supported for finiteness ops. Skipping it..."; + TVM_FFI_THROW(InternalError) << "Data type " << x.dtype() + << " not supported for finiteness ops. Skipping it..."; } } @@ -998,7 +1003,7 @@ PrimExpr prod(PrimExpr source, ffi::Array rdom, ffi::Array in // fmod PrimExpr fmod(PrimExpr x, PrimExpr y, Span span) { BinaryOpMatchTypes(x, y, span); - ICHECK(x.dtype().is_float()) << "fmod only applies to float"; + TVM_FFI_ICHECK(x.dtype().is_float()) << "fmod only applies to float"; static auto op = Op::Get("tir.fmod"); return tir::Call(x.dtype(), op, {x, y}, span); } @@ -1159,9 +1164,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { } else if (auto opt = args[0].try_cast()) { *ret = tir::make_const(args[1].cast(), *opt, args[2].cast()); } else { - LOG(FATAL) << "First argument to tvm.tir.const must be int, float, or bool, " - << "but instead received argument with type code " - << args[0].GetTypeKey(); + TVM_FFI_THROW(InternalError) + << "First argument to tvm.tir.const must be int, float, or bool, " + << "but instead received argument with type code " + << args[0].GetTypeKey(); } }) .def("node.LargeUIntImm", LargeUIntImm) diff --git a/src/tir/transform/annotate_device_regions.cc b/src/tir/transform/annotate_device_regions.cc index 47b3df5fdaa3..755adade0cf4 100644 --- a/src/tir/transform/annotate_device_regions.cc +++ b/src/tir/transform/annotate_device_regions.cc @@ -62,7 +62,7 @@ namespace transform { Pass AnnotateDeviceRegions() { auto pass_func = [](PrimFunc func, IRModule mod, PassContext ctx) -> PrimFunc { auto opt_target = func->GetAttr(tvm::attr::kTarget); - ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute"; + TVM_FFI_ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute"; Target target = opt_target.value(); if (target->GetHost()) { diff --git a/src/tir/transform/arg_binder.cc b/src/tir/transform/arg_binder.cc index 1b85d7d21132..de44c1449d0a 100644 --- a/src/tir/transform/arg_binder.cc +++ b/src/tir/transform/arg_binder.cc @@ -37,8 +37,8 @@ void BinderAddAssert(arith::Analyzer* ana, PrimExpr cond, const std::string& arg std::vector* asserts) { PrimExpr scond = ana->Simplify(cond); if (is_zero(scond)) { - LOG(FATAL) << "Bind have an unmet assertion: " << cond << ", " - << " on argument " << arg_name; + TVM_FFI_THROW(InternalError) << "Bind have an unmet assertion: " << cond << ", " + << " on argument " << arg_name; } if (!is_one(scond)) { std::ostringstream os; @@ -49,7 +49,7 @@ void BinderAddAssert(arith::Analyzer* ana, PrimExpr cond, const std::string& arg bool ArgBinder::Bind_(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, bool with_lets) { - ICHECK_EQ(arg.dtype(), value.dtype()); + TVM_FFI_ICHECK_EQ(arg.dtype(), value.dtype()); if (const VarNode* v = arg.as()) { auto it = def_map_->find(v); if (it == def_map_->end()) { @@ -78,7 +78,7 @@ void ArgBinder::Bind(const PrimExpr& arg, const PrimExpr& value, const std::stri void ArgBinder::BindArray(const ffi::Array& arg, const ffi::Array& value, const std::string& arg_name) { - ICHECK_EQ(arg.size(), value.size()) << "Argument " << arg_name << " array size mismatch"; + TVM_FFI_ICHECK_EQ(arg.size(), value.size()) << "Argument " << arg_name << " array size mismatch"; for (size_t i = 0; i < arg.size(); ++i) { std::ostringstream os; os << arg_name << "[" << i << "]"; @@ -88,8 +88,9 @@ void ArgBinder::BindArray(const ffi::Array& arg, const ffi::Arraydtype, value->dtype) + TVM_FFI_ICHECK_EQ(arg.scope(), value.scope()) + << "Argument " << arg_name << " Buffer bind scope mismatch"; + TVM_FFI_ICHECK_EQ(arg->dtype, value->dtype) << "Argument " << arg_name << " Buffer bind data type mismatch"; if (value->data_alignment % arg->data_alignment != 0) { LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement " @@ -100,7 +101,7 @@ void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::st if (value->elem_offset.defined()) { // bind pointer and offset. if (is_zero(arg->elem_offset)) { - ICHECK(is_zero(value->elem_offset)) + TVM_FFI_ICHECK(is_zero(value->elem_offset)) << "Trying to bind a Buffer with offset into one without offset " << " required elem_offset=" << arg->elem_offset << ", provided elem_offset=" << value->elem_offset; @@ -119,10 +120,10 @@ void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::st } if (arg->shape.size() < value->shape.size()) { - ICHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch"; + TVM_FFI_ICHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch"; size_t diff = value->shape.size() - arg->shape.size(); for (size_t i = 0; i < diff; ++i) { - ICHECK(is_one(analyzer_.Simplify(value->shape[i]))) + TVM_FFI_ICHECK(is_one(analyzer_.Simplify(value->shape[i]))) << "Argument " << arg_name << " shape mismatch" << arg->shape << " vs " << value->shape; } for (size_t i = 0; i < arg->shape.size(); ++i) { @@ -131,8 +132,8 @@ void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::st this->Bind(arg->shape[i], value->shape[i + diff], os.str()); } if (value->strides.size() != 0) { - ICHECK_EQ(arg->strides.size(), arg->shape.size()); - ICHECK_EQ(value->strides.size(), value->shape.size()); + TVM_FFI_ICHECK_EQ(arg->strides.size(), arg->shape.size()); + TVM_FFI_ICHECK_EQ(value->strides.size(), value->shape.size()); for (size_t i = 0; i < arg->strides.size(); ++i) { std::ostringstream os; os << arg_name << ".strides[" << i << "]"; diff --git a/src/tir/transform/dtype_conversion.cc b/src/tir/transform/dtype_conversion.cc index 85341981d3c0..84530a778d5c 100644 --- a/src/tir/transform/dtype_conversion.cc +++ b/src/tir/transform/dtype_conversion.cc @@ -36,14 +36,14 @@ PrimExpr DTypeConversion(PrimExpr src_value, DataType tgt_dtype, RoundingMode ro DataType src_dtype = src_value.dtype(); // Step 1: check dtype // The lanes of src dtype and target dtype must match. - CHECK_EQ(src_dtype.lanes(), tgt_dtype.lanes()) + TVM_FFI_ICHECK_EQ(src_dtype.lanes(), tgt_dtype.lanes()) << "The lanes for data type for source value must matches the target datatype."; auto is_floating_point = [](DataType dtype) { return dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() || dtype.is_float4(); }; // Both source dtype and target dtype should be floating point. - CHECK(is_floating_point(src_dtype) && is_floating_point(tgt_dtype)); + TVM_FFI_ICHECK(is_floating_point(src_dtype) && is_floating_point(tgt_dtype)); FloatConfig src_fp = FloatConfig::FromDataType(src_value.dtype()), tgt_fp = FloatConfig::FromDataType(tgt_dtype); int exponent_delta = tgt_fp.exponent - src_fp.exponent; @@ -54,7 +54,7 @@ PrimExpr DTypeConversion(PrimExpr src_value, DataType tgt_dtype, RoundingMode ro PrimExpr src_uint_value = ReinterpretAsUInt(src_value); if (mantissa_delta < 0) { // use rounding - CHECK(round_mode == RoundingMode::kHalfToEven) + TVM_FFI_ICHECK(round_mode == RoundingMode::kHalfToEven) << "Currently we only support HalfToEven rounding mode."; PrimExpr rounding_bias = ((src_uint_value >> (-mantissa_delta)) & 1) + make_const(src_uint, (int64_t(1) << (-mantissa_delta - 1)) - 1); diff --git a/src/tir/transform/dtype_conversion.h b/src/tir/transform/dtype_conversion.h index 310ed21c7648..bc258301fa6c 100644 --- a/src/tir/transform/dtype_conversion.h +++ b/src/tir/transform/dtype_conversion.h @@ -99,8 +99,8 @@ class FloatConfig { * \return The FloatConfig class containing internal floating point representation. */ static FloatConfig FromDataType(DataType dtype) { - CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() || - dtype.is_float4()) + TVM_FFI_ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || + dtype.is_float6() || dtype.is_float4()) << "FloatConfig is only applicable to floating point data types, got " << dtype << " instead."; if (dtype.is_float()) { @@ -147,7 +147,7 @@ class FloatConfig { // UE8M0 format, not consistent with IEEE-754 return FloatConfig(8, 0, 127, InftyStyle::kNone, NaNStyle::kAllOnes); default: - LOG(FATAL) << "Unknown float8 variant: " << dtype; + TVM_FFI_THROW(InternalError) << "Unknown float8 variant: " << dtype; } } else if (dtype.is_float6()) { // float6 switch (dtype.code()) { @@ -158,7 +158,7 @@ class FloatConfig { // E3M2 format, not consistent with IEEE-754 return FloatConfig(3, 2, 3, InftyStyle::kNone, NaNStyle::kNone); default: - LOG(FATAL) << "Unknown float6 variant: " << dtype; + TVM_FFI_THROW(InternalError) << "Unknown float6 variant: " << dtype; } } else { // float4 diff --git a/src/tir/transform/flatten_buffer.cc b/src/tir/transform/flatten_buffer.cc index 800177fa5ca9..87d770c93eb3 100644 --- a/src/tir/transform/flatten_buffer.cc +++ b/src/tir/transform/flatten_buffer.cc @@ -61,7 +61,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { explicit BufferFlattener(arith::Analyzer* ana) : IRMutatorWithAnalyzer(ana) {} Stmt VisitStmt_(const SBlockNode* op) final { - ICHECK_EQ(op->match_buffers.size(), 0) + TVM_FFI_ICHECK_EQ(op->match_buffers.size(), 0) << "Unexpected MatchBufferRegion found during tir.transform.FlattenBuffer. " << "All MatchBufferRegion should be removed in tir.transform.LowerMatchBuffer."; @@ -124,7 +124,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { Buffer flattened = GetFlattenedBuffer(buffer); return flattened->shape; } else { - ICHECK(decl_buffer->buffer->axis_separators.empty()) + TVM_FFI_ICHECK(decl_buffer->buffer->axis_separators.empty()) << "DeclBuffer node doesn't match Allocate extents, but also shouldn't be " "flattened to 1-d physical memory"; } @@ -193,7 +193,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { // TODO(Lunderberg): Move the handling of boolean into a // dedicated pass. if (store_returns_bool) { - ICHECK_EQ(store->buffer->dtype, DataType::Int(8)) + TVM_FFI_ICHECK_EQ(store->buffer->dtype, DataType::Int(8)) << "Expected int8 backing array for boolean tensor"; auto writer = store.CopyOnWrite(); writer->value = tvm::cast(DataType::Int(8), store->value); @@ -210,7 +210,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { // TODO(Lunderberg): Move the handling of boolean into a // dedicated pass. if (load_returns_bool) { - ICHECK_EQ(load->buffer->dtype, DataType::Int(8)) + TVM_FFI_ICHECK_EQ(load->buffer->dtype, DataType::Int(8)) << "Expected int8 backing array for boolean tensor"; load.CopyOnWrite()->dtype = DataType::Int(8); return tvm::cast(DataType::Bool(), load); @@ -227,7 +227,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { template Node VisitBufferAccess(Node node) { - ICHECK(node->buffer.defined()); + TVM_FFI_ICHECK(node->buffer.defined()); auto flattened_indices = GetSimplifiedElemOffset(node->buffer, node->indices); Buffer flattened_buffer = GetFlattenedBuffer(node->buffer); @@ -255,7 +255,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { ffi::Array flattened_max = GetSimplifiedElemOffset(orig_buf, max_values); ffi::Array flattened_ranges; - ICHECK_EQ(flattened_min.size(), flattened_max.size()); + TVM_FFI_ICHECK_EQ(flattened_min.size(), flattened_max.size()); for (size_t i = 0; i < flattened_min.size(); i++) { flattened_ranges.push_back(Range(flattened_min[i], flattened_max[i] + 1)); } diff --git a/src/tir/transform/force_narrow_index_to_i32.cc b/src/tir/transform/force_narrow_index_to_i32.cc index e908d351255c..46ff7739d2ad 100644 --- a/src/tir/transform/force_narrow_index_to_i32.cc +++ b/src/tir/transform/force_narrow_index_to_i32.cc @@ -38,8 +38,9 @@ class Int32DTypeNarrower : public IndexDataTypeNormalizer { // Check if the integer parameter buffers have dtype other than int32. for (auto it : func->buffer_map) { if (it.second->dtype.is_int() && it.second->dtype.bits() > 32) { - LOG(FATAL) << "The buffer " << it.second << " in the function buffer map has dtype " - << it.second->dtype << ". The function is " << func; + TVM_FFI_THROW(InternalError) + << "The buffer " << it.second << " in the function buffer map has dtype " + << it.second->dtype << ". The function is " << func; } } @@ -54,7 +55,7 @@ class Int32DTypeNarrower : public IndexDataTypeNormalizer { PrimExpr VisitExpr_(const IntImmNode* op) final { // ignore the enabled condition and always rewrite i64 if (op->dtype == DataType::Int(64)) { - ICHECK_LE(op->value, Downcast(max_value(target_data_type_))->value); + TVM_FFI_ICHECK_LE(op->value, Downcast(max_value(target_data_type_))->value); return IntImm(DataType::Int(32), op->value); } return ffi::GetRef(op); @@ -65,8 +66,9 @@ class Int32DTypeNarrower : public IndexDataTypeNormalizer { // Check if the allocated integer buffers have dtype other than int32. for (const Buffer& buf : block_->alloc_buffers) { if (buf->dtype.is_int() && buf->dtype.bits() > 32) { - LOG(FATAL) << "The buffer " << buf << " allocated in the function has dtype " << buf->dtype - << ". The function is " << func_; + TVM_FFI_THROW(InternalError) + << "The buffer " << buf << " allocated in the function has dtype " << buf->dtype + << ". The function is " << func_; } } return block_; diff --git a/src/tir/transform/inline_private_functions.cc b/src/tir/transform/inline_private_functions.cc index 030aac3c75dd..a44cdb37add2 100644 --- a/src/tir/transform/inline_private_functions.cc +++ b/src/tir/transform/inline_private_functions.cc @@ -224,12 +224,12 @@ class PrimFuncInliner : StmtExprMutator { Stmt InlineArguments(const GlobalVar& gvar, PrimFunc callee, const ffi::Array& args) const { - CHECK_EQ(callee->params.size(), args.size()) + TVM_FFI_ICHECK_EQ(callee->params.size(), args.size()) << "Callee " << gvar << " accepts " << callee->params.size() << " parameters (" << callee->params << "), but is called with " << args.size() << " arguments (" << args << ")"; - ICHECK(callee->buffer_map.empty()) + TVM_FFI_ICHECK(callee->buffer_map.empty()) << "Inlining of PrimFuncs with buffer arguments is not yet supported, " << "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map; diff --git a/src/tir/transform/ir_utils.cc b/src/tir/transform/ir_utils.cc index 15917d891e4e..4134202b7887 100644 --- a/src/tir/transform/ir_utils.cc +++ b/src/tir/transform/ir_utils.cc @@ -42,47 +42,47 @@ Stmt MergeNest(const std::vector& nest, Stmt body) { Stmt s = *ri; if (const auto* for_ = s.as()) { auto n = ffi::make_object(*for_); - ICHECK(is_no_op(n->body)); + TVM_FFI_ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* let = s.as()) { auto n = ffi::make_object(*let); - ICHECK(is_no_op(n->body)); + TVM_FFI_ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* attr = s.as()) { auto n = ffi::make_object(*attr); - ICHECK(is_no_op(n->body)); + TVM_FFI_ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* ite = s.as()) { auto n = ffi::make_object(*ite); - ICHECK(is_no_op(n->then_case)); - ICHECK(!n->else_case); + TVM_FFI_ICHECK(is_no_op(n->then_case)); + TVM_FFI_ICHECK(!n->else_case); n->then_case = body; body = Stmt(n); } else if (const auto* seq = s.as()) { auto n = ffi::make_object(*seq); - ICHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1])); + TVM_FFI_ICHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1])); n->seq.Set(n->size() - 1, body); body = Stmt(n); } else if (const auto* assert_ = s.as()) { auto n = ffi::make_object(*assert_); - ICHECK(is_no_op(n->body)); + TVM_FFI_ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* alloc = s.as()) { auto n = ffi::make_object(*alloc); - ICHECK(is_no_op(n->body)); + TVM_FFI_ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else if (const auto* decl_buffer = s.as()) { auto n = ffi::make_object(*decl_buffer); - ICHECK(is_no_op(n->body)); + TVM_FFI_ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); } else { - LOG(FATAL) << "not supported nest type"; + TVM_FFI_THROW(InternalError) << "not supported nest type"; } } return body; @@ -527,16 +527,16 @@ Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } ffi::String GetPtrStorageScope(Var buffer_var) { const auto* ptr_type = buffer_var->type_annotation.as(); - ICHECK(ptr_type) << "The provided variable is not of pointer type"; + TVM_FFI_ICHECK(ptr_type) << "The provided variable is not of pointer type"; return ptr_type->storage_scope; } ffi::Array GetBufferAllocationShape(const Buffer& buffer) { ffi::Array alloc_shape = buffer->shape; if (buffer->strides.size()) { - ICHECK_EQ(buffer->shape.size(), buffer->strides.size()); + TVM_FFI_ICHECK_EQ(buffer->shape.size(), buffer->strides.size()); for (size_t i = buffer->strides.size() - 1; i > 0; --i) { - ICHECK( + TVM_FFI_ICHECK( arith::Analyzer().CanProveEqual(floormod(buffer->strides[i - 1], buffer->strides[i]), 0)); alloc_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]); } @@ -548,7 +548,7 @@ ffi::Array ConvertIndices(const MatchBufferRegion& match_buffer, const ffi::Array& indices) { const Buffer& target = match_buffer->buffer; const BufferRegion& source = match_buffer->source; - ICHECK_EQ(indices.size(), target->shape.size()); + TVM_FFI_ICHECK_EQ(indices.size(), target->shape.size()); arith::Analyzer analyzer; ffi::Array result; @@ -556,7 +556,7 @@ ffi::Array ConvertIndices(const MatchBufferRegion& match_buffer, size_t offset = source->region.size() - indices.size(); for (size_t i = 0; i < offset; ++i) { const Range& range = source->region[i]; - ICHECK(analyzer.CanProve(range->extent == 1)); + TVM_FFI_ICHECK(analyzer.CanProve(range->extent == 1)); result.push_back(range->min); } for (size_t i = 0; i < indices.size(); ++i) { @@ -570,7 +570,7 @@ ffi::Array ConvertIndices(const MatchBufferRegion& match_buffer, Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region) { const Buffer& target = match_buffer->buffer; const BufferRegion& source = match_buffer->source; - ICHECK_EQ(region.size(), target->shape.size()); + TVM_FFI_ICHECK_EQ(region.size(), target->shape.size()); arith::Analyzer analyzer; Region result; @@ -578,7 +578,7 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region size_t offset = source->region.size() - region.size(); for (size_t i = 0; i < offset; ++i) { const Range& source_range = source->region[i]; - ICHECK(analyzer.CanProve(source_range->extent == 1)); + TVM_FFI_ICHECK(analyzer.CanProve(source_range->extent == 1)); result.push_back(Range::FromMinExtent(source_range->min, 1)); } for (size_t i = 0; i < region.size(); ++i) { @@ -719,7 +719,7 @@ void ConditionalBoundsContext::ExitWithScope() { } else { // recover bound for free var auto hint_it = hint_map_->find(var); - ICHECK(hint_it != hint_map_->end()); + TVM_FFI_ICHECK(hint_it != hint_map_->end()); if (p.second.IsNothing()) { hint_map_->erase(hint_it); } else { @@ -730,9 +730,9 @@ void ConditionalBoundsContext::ExitWithScope() { } std::pair GetAsyncWaitAttributes(const AttrStmtNode* op) { - ICHECK(op && op->attr_key == tir::attr::async_wait_queue_scope); + TVM_FFI_ICHECK(op && op->attr_key == tir::attr::async_wait_queue_scope); auto inner = op->body.as(); - ICHECK(inner && inner->attr_key == tir::attr::async_wait_inflight_count); + TVM_FFI_ICHECK(inner && inner->attr_key == tir::attr::async_wait_inflight_count); return std::make_pair(op->value, inner->value); } @@ -765,7 +765,7 @@ class StorageAlignCollector : public StmtVisitor { int buffer_index = storage_align_tuple.get<0>(); // the first buffer idx info is meaningless for allocate // stmt and should set as negative intentionally. - ICHECK_EQ(buffer_index, -1); + TVM_FFI_ICHECK_EQ(buffer_index, -1); storage_align_[op->buffer_var].push_back(storage_align_tuple); } } @@ -786,7 +786,7 @@ int Stoi(const std::string& str) { try { return std::stoi(str); } catch (std::invalid_argument& e) { - LOG(FATAL) << "Cannot convert \"" << str << "\" to int"; + TVM_FFI_THROW(InternalError) << "Cannot convert \"" << str << "\" to int"; throw; } } diff --git a/src/tir/transform/ir_utils.h b/src/tir/transform/ir_utils.h index c8d72e3b14f2..51e229332167 100644 --- a/src/tir/transform/ir_utils.h +++ b/src/tir/transform/ir_utils.h @@ -153,11 +153,11 @@ inline Stmt TVMStructSet(Var handle, int index, builtin::TVMStructFieldKind kind * \return The corresponding API type. */ inline DataType APIType(DataType t) { - ICHECK(!t.is_void()) << "Cannot pass void type through packed API."; + TVM_FFI_ICHECK(!t.is_void()) << "Cannot pass void type through packed API."; if (t.is_handle()) return t; - ICHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API."; + TVM_FFI_ICHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API."; if (t.is_bool() || t.is_uint() || t.is_int()) return DataType::Int(64); - ICHECK(t.is_float()); + TVM_FFI_ICHECK(t.is_float()); return DataType::Float(64); } @@ -184,7 +184,7 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) { * \return the PrimExpr that represents the constant */ inline PrimExpr ConstInt32(size_t index) { - ICHECK_LE(index, std::numeric_limits::max()); + TVM_FFI_ICHECK_LE(index, std::numeric_limits::max()); return make_const(DataType::Int(32), static_cast(index)); } @@ -297,7 +297,7 @@ struct FragmentInfo { } else if (scope == "wmma.accumulator") { return m * n; } else { - ICHECK(0); + TVM_FFI_ICHECK(0); throw; } } diff --git a/src/tir/transform/lower_custom_datatypes.cc b/src/tir/transform/lower_custom_datatypes.cc index 90725fe5befc..2e7264bc5f3d 100644 --- a/src/tir/transform/lower_custom_datatypes.cc +++ b/src/tir/transform/lower_custom_datatypes.cc @@ -54,9 +54,10 @@ class CustomDatatypesLowerer : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); if (to_be_lowered) { auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code); - ICHECK(lower) << "Cast lowering function for target " << target_ << " destination type " - << static_cast(type_code) << " source type " - << static_cast(src_type_code) << " not found"; + TVM_FFI_ICHECK(lower) << "Cast lowering function for target " << target_ + << " destination type " << static_cast(type_code) + << " source type " << static_cast(src_type_code) + << " not found"; return (*lower)(expr).cast(); } return expr; @@ -67,8 +68,8 @@ class CustomDatatypesLowerer : public StmtExprMutator { auto e = ffi::GetRef(imm); if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { auto lower = datatype::GetFloatImmLowerFunc(target_, type_code); - ICHECK(lower) << "FloatImm lowering function for target " << target_ << " type " - << static_cast(type_code) << " not found"; + TVM_FFI_ICHECK(lower) << "FloatImm lowering function for target " << target_ << " type " + << static_cast(type_code) << " not found"; return (*lower)(e).cast(); } return e; @@ -188,29 +189,29 @@ class CustomDatatypesLowerer : public StmtExprMutator { call = expr.as(); if (to_be_lowered) { auto op = call->op.as(); - ICHECK(op != nullptr) << "Lowering non-intrinsic Calls not implemented"; + TVM_FFI_ICHECK(op != nullptr) << "Lowering non-intrinsic Calls not implemented"; auto lower = datatype::GetIntrinLowerFunc(target_, op->name, call->dtype.code()); - ICHECK(lower) << "Intrinsic lowering function for target " << target_ << ", intrinsic name " - << op->name << ", type " << static_cast(call->dtype.code()) - << " not found"; + TVM_FFI_ICHECK(lower) << "Intrinsic lowering function for target " << target_ + << ", intrinsic name " << op->name << ", type " + << static_cast(call->dtype.code()) << " not found"; return (*lower)(expr).cast(); } return expr; } -#define TVM_DEFINE_MUTATE_CUSTOM_DTYPE(OP, NodeName) \ - PrimExpr VisitExpr_(const NodeName* op) final { \ - auto type_code = op->dtype.code(); \ - bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ - PrimExpr expr = StmtExprMutator::VisitExpr_(op); \ - op = expr.as(); \ - if (to_be_lowered) { \ - auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ - ICHECK(lower) << #OP " lowering function for target " << target_ << " type " \ - << static_cast(type_code) << " not found"; \ - return (*lower)(expr).cast(); \ - } \ - return expr; \ +#define TVM_DEFINE_MUTATE_CUSTOM_DTYPE(OP, NodeName) \ + PrimExpr VisitExpr_(const NodeName* op) final { \ + auto type_code = op->dtype.code(); \ + bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ + PrimExpr expr = StmtExprMutator::VisitExpr_(op); \ + op = expr.as(); \ + if (to_be_lowered) { \ + auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ + TVM_FFI_ICHECK(lower) << #OP " lowering function for target " << target_ << " type " \ + << static_cast(type_code) << " not found"; \ + return (*lower)(expr).cast(); \ + } \ + return expr; \ } TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Add, AddNode); @@ -243,7 +244,7 @@ Pass LowerCustomDatatypes() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - ICHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute"; + TVM_FFI_ICHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute"; n->body = CustomDatatypesLowerer(target.value()->kind->name)(std::move(n->body)); return f; diff --git a/src/tir/transform/lower_device_kernel_launch.cc b/src/tir/transform/lower_device_kernel_launch.cc index fcf85ce6b445..363945a01b29 100644 --- a/src/tir/transform/lower_device_kernel_launch.cc +++ b/src/tir/transform/lower_device_kernel_launch.cc @@ -91,23 +91,23 @@ class DeviceInfoCollector : public StmtVisitor { private: PrimExpr GetArgument(const ffi::String& launch_param) const { if (launch_param == tvm::runtime::launch_param::kUseDynamicSharedMemoryTag) { - CHECK(dyn_shmem_size.defined()) + TVM_FFI_ICHECK(dyn_shmem_size.defined()) << "Compute kernel requires launch parameter \"" << launch_param << "\", but PrimFunc did not contain Allocate node with shared dynamic scope."; return dyn_shmem_size.value(); } auto extent = thread_extent.Get(launch_param); - CHECK(extent) << "Compute kernel requires launch parameter \"" << launch_param - << "\", but PrimFunc does not contain AttrStmt \"" << attr::thread_extent - << "\" defining this thread extent"; + TVM_FFI_ICHECK(extent) << "Compute kernel requires launch parameter \"" << launch_param + << "\", but PrimFunc does not contain AttrStmt \"" << attr::thread_extent + << "\" defining this thread extent"; return extent.value(); } void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); - ICHECK_NE(iv->thread_tag.length(), 0U); + TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); // thread_extent can appear multiple times // use the first appearance as def. if (!defined_thread.count(iv.get())) { @@ -123,8 +123,9 @@ class DeviceInfoCollector : public StmtVisitor { void VisitStmt_(const AllocateNode* op) final { auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { - ICHECK(!dyn_shmem_size.defined()) << "Only one dynamic shared memory allocation is allowed."; - ICHECK_GT(op->extents.size(), 0); + TVM_FFI_ICHECK(!dyn_shmem_size.defined()) + << "Only one dynamic shared memory allocation is allowed."; + TVM_FFI_ICHECK_GT(op->extents.size(), 0); PrimExpr dyn_size = Integer(1); for (const auto& extent : op->extents) { @@ -159,9 +160,9 @@ class ReturnRemover : public StmtExprMutator { Stmt VisitStmt_(const EvaluateNode* op) override { if (auto* call = op->value.as()) { if (call->op.same_as(builtin::ret())) { - ICHECK_EQ(call->args.size(), 1); + TVM_FFI_ICHECK_EQ(call->args.size(), 1); auto as_int = call->args[0].as(); - ICHECK(as_int && as_int->value == 0) + TVM_FFI_ICHECK(as_int && as_int->value == 0) << "Device kernel may only contain successful return, T.ret(0)"; return Evaluate(0); } @@ -171,7 +172,8 @@ class ReturnRemover : public StmtExprMutator { PrimExpr VisitExpr_(const CallNode* op) override { if (op->op.same_as(builtin::ret())) { - LOG(FATAL) << "Call to builtin::ret() should only appear within an Evaluate node"; + TVM_FFI_THROW(InternalError) + << "Call to builtin::ret() should only appear within an Evaluate node"; } return Parent::VisitExpr_(op); } @@ -186,9 +188,9 @@ class DeviceKernelMutator : public StmtExprMutator { : device_info_map_(std::move(device_info_map)) {} PrimFunc RewriteKernelLaunchSite(const GlobalVar& gvar, PrimFunc func) { - ICHECK(!current_target_.defined()); + TVM_FFI_ICHECK(!current_target_.defined()); auto it = device_info_map_.find(gvar.get()); - ICHECK(it != device_info_map_.end()); + TVM_FFI_ICHECK(it != device_info_map_.end()); current_target_ = it->second.target; auto body = VisitStmt(func->body); @@ -203,7 +205,7 @@ class DeviceKernelMutator : public StmtExprMutator { PrimFunc UpdateKernelAttributes(const GlobalVar& gvar, PrimFunc func) const { bool is_kernel_launch = device_kernel_launch_.count(gvar.get()); bool is_call_extern = extern_function_call_.count(gvar.get()); - CHECK(!is_kernel_launch || !is_call_extern) + TVM_FFI_ICHECK(!is_kernel_launch || !is_call_extern) << "Function " << gvar << " has multiple callees, " << "and would need to be lowered into a call_extern at some call sites, " << "and a device kernel launch at others. " @@ -244,7 +246,7 @@ class DeviceKernelMutator : public StmtExprMutator { if (!gvar) return node; auto it = device_info_map_.find(gvar); - ICHECK(it != device_info_map_.end()) + TVM_FFI_ICHECK(it != device_info_map_.end()) << "CallNode attempted subroutine call to " << gvar->name_hint << ", but " << gvar->name_hint << " did not appear within the IRModule"; const KernelInfo& dev_info = it->second; @@ -274,7 +276,7 @@ class DeviceKernelMutator : public StmtExprMutator { return Call(node->dtype, builtin::call_extern(), args); } - ICHECK(dev_info.launch_params.defined()) + TVM_FFI_ICHECK(dev_info.launch_params.defined()) << "CallNode attempted kernel launch to " << gvar->name_hint << " on target " << dev_info.target << ", but subroutine " << gvar->name_hint << " did not have the tir::attr::kKernelLaunchParams attribute " @@ -287,7 +289,7 @@ class DeviceKernelMutator : public StmtExprMutator { // expressions that are valid within the caller. ffi::Map param_map = [&]() { ffi::Map param_map; - CHECK_EQ(node->args.size(), dev_info.params.size()) + TVM_FFI_ICHECK_EQ(node->args.size(), dev_info.params.size()) << "Function " << gvar->name_hint << " accepts " << dev_info.params.size() << " arguments as input, but is called using " << node->args.size() << " arguments"; for (size_t i = 0; i < node->args.size(); i++) { diff --git a/src/tir/transform/lower_intrin.cc b/src/tir/transform/lower_intrin.cc index e80f2f133e66..fc85e541038f 100644 --- a/src/tir/transform/lower_intrin.cc +++ b/src/tir/transform/lower_intrin.cc @@ -73,7 +73,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { if (f != nullptr) { PrimExpr e = ffi::GetRef(op); PrimExpr r = f(e); - ICHECK(r.defined()) << "intrinsic rule must always return valid Expr"; + TVM_FFI_ICHECK(r.defined()) << "intrinsic rule must always return valid Expr"; if (!r.same_as(e)) { r = this->VisitExpr(r); if (r.defined()) { @@ -104,7 +104,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { if (op == nullptr) return ret; int shift; const DataType& dtype = op->dtype; - ICHECK(dtype.is_int() || dtype.is_uint()); + TVM_FFI_ICHECK(dtype.is_int() || dtype.is_uint()); if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) { // lower to right shift if possible. @@ -165,7 +165,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // Lower floordiv to native truncdiv. int shift; const DataType& dtype = op->dtype; - ICHECK(dtype.is_int() || dtype.is_uint()); + TVM_FFI_ICHECK(dtype.is_int() || dtype.is_uint()); if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) { // lower to masking if possible. @@ -330,7 +330,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return std::nullopt; } int64_t c_value = ((b_value - 1) - const_int_bound_a->min_value) / b_value; - ICHECK_GT(c_value, 0); + TVM_FFI_ICHECK_GT(c_value, 0); // NOTE: the c_value * b_value risks in overflow if (c_value > max_value_of_dtype / b_value) return std::nullopt; // need to check if the offset numerator will overflow @@ -361,7 +361,7 @@ Pass LowerIntrin() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - ICHECK(target.defined()) << "LowerIntrin: Require the target attribute"; + TVM_FFI_ICHECK(target.defined()) << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; auto mtriple = target.value()->GetAttr("mtriple", ""); n->body = diff --git a/src/tir/transform/lower_tvm_builtin.cc b/src/tir/transform/lower_tvm_builtin.cc index 4add282272a4..2520caeb7b1c 100644 --- a/src/tir/transform/lower_tvm_builtin.cc +++ b/src/tir/transform/lower_tvm_builtin.cc @@ -108,9 +108,9 @@ class BuiltinLower : public StmtExprMutator { } void AssertMaxIsValid() const { - ICHECK((max_sizes.shape_stack >= run_sizes.shape_stack) || - (max_sizes.array_stack >= run_sizes.array_stack) || - (max_sizes.arg_stack >= run_sizes.arg_stack)); + TVM_FFI_ICHECK((max_sizes.shape_stack >= run_sizes.shape_stack) || + (max_sizes.array_stack >= run_sizes.array_stack) || + (max_sizes.arg_stack >= run_sizes.arg_stack)); } }; @@ -132,7 +132,7 @@ class BuiltinLower : public StmtExprMutator { precheck.VisitStmt(stmt); - ICHECK_EQ(precheck.alloca_scope_.size(), 1); + TVM_FFI_ICHECK_EQ(precheck.alloca_scope_.size(), 1); return precheck.alloca_scope_[0].max_sizes; } @@ -174,7 +174,7 @@ class BuiltinLower : public StmtExprMutator { stmt = this->VisitStmt(stmt); - ICHECK(!alloca_scope_.empty()); + TVM_FFI_ICHECK(!alloca_scope_.empty()); alloca_scope_.pop_back(); return stmt; @@ -193,11 +193,11 @@ class BuiltinLower : public StmtExprMutator { // make_stack_shape only happens within a call_packed. // We could relax this in the future if we want to // introduce root scope as a separate scope - ICHECK_EQ(alloca_scope_.size(), scope_size) + TVM_FFI_ICHECK_EQ(alloca_scope_.size(), scope_size) << "alloca_scope_ length is different before and after recursion"; - ICHECK_EQ(scope.run_sizes.shape_stack, -1) + TVM_FFI_ICHECK_EQ(scope.run_sizes.shape_stack, -1) << "Expect no tvm_stack_make_shape outside of CallNodes"; - ICHECK_EQ(scope.run_sizes.array_stack, 0) + TVM_FFI_ICHECK_EQ(scope.run_sizes.array_stack, 0) << "Expect no tvm_stack_make_array outside of CallNodes"; } @@ -251,8 +251,8 @@ class BuiltinLower : public StmtExprMutator { // set total_bytes to uint64 to avoid overflow total_bytes = total_bytes * op->extents[i]; } - ICHECK(device_type_) << "Unknown device type in current IR"; - ICHECK(device_id_) << "Unknown device id in current IR"; + TVM_FFI_ICHECK(device_type_) << "Unknown device type in current IR"; + TVM_FFI_ICHECK(device_id_) << "Unknown device id in current IR"; Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {})); Stmt alloc_nullptr_check = IfThenElse( @@ -361,14 +361,14 @@ class BuiltinLower : public StmtExprMutator { } StringImm GetDeviceMethodName(const char* method_name) const { - CHECK(device_type_) << "Method " << method_name << " requires the device type, " - << "but occurred outside of a \"device_type\" annotation"; + TVM_FFI_ICHECK(device_type_) << "Method " << method_name << " requires the device type, " + << "but occurred outside of a \"device_type\" annotation"; auto as_int = device_type_.as(); - CHECK(as_int) << "Method " << method_name - << " requires the device type to be a DLDeviceType enum value, " - << "but was instead the expression " << device_type_ << " with type " - << device_type_.value()->GetTypeKey(); + TVM_FFI_ICHECK(as_int) << "Method " << method_name + << " requires the device type to be a DLDeviceType enum value, " + << "but was instead the expression " << device_type_ << " with type " + << device_type_.value()->GetTypeKey(); ffi::String device_name = runtime::DLDeviceType2Str(as_int->value); return StringImm("device_api." + device_name + "." + method_name); @@ -416,7 +416,7 @@ class BuiltinLower : public StmtExprMutator { // call shape PrimExpr MakeShape(const CallNode* op) { // if args.size() == 0, it represents a scalar shape () - ICHECK(!alloca_scope_.empty()); + TVM_FFI_ICHECK(!alloca_scope_.empty()); auto& scope = alloca_scope_.back(); auto& prep_seq = prep_seq_stack_.back(); if (scope.run_sizes.shape_stack == -1) { @@ -435,7 +435,7 @@ class BuiltinLower : public StmtExprMutator { } // make array PrimExpr MakeArray(const CallNode* op) { - ICHECK(!alloca_scope_.empty()); + TVM_FFI_ICHECK(!alloca_scope_.empty()); auto& scope = alloca_scope_.back(); auto& prep_seq = prep_seq_stack_.back(); @@ -471,8 +471,8 @@ class BuiltinLower : public StmtExprMutator { } prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrByteOffset, cast(DataType::UInt(64), byte_offset))); - ICHECK(device_type_) << "Unknown device type in current IR"; - ICHECK(device_id_) << "Unknown device id in current IR"; + TVM_FFI_ICHECK(device_type_) << "Unknown device type in current IR"; + TVM_FFI_ICHECK(device_id_) << "Unknown device id in current IR"; prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceId, cast(DataType::Int(32), device_id_.value()))); prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceType, @@ -505,7 +505,7 @@ class BuiltinLower : public StmtExprMutator { } else if (api_dtype.is_handle()) { return ffi::TypeIndex::kTVMFFIOpaquePtr; } else { - LOG(FATAL) << "Unsupported type: " << api_dtype; + TVM_FFI_THROW(InternalError) << "Unsupported type: " << api_dtype; } }(); @@ -606,8 +606,8 @@ class BuiltinLower : public StmtExprMutator { } Stmt MakeNdMemAllocWithScope(const LetStmtNode* let, const CallNode* call) { - ICHECK(device_type_) << "Unknown device type in current IR"; - ICHECK(device_id_) << "Unknown device id in current IR"; + TVM_FFI_ICHECK(device_type_) << "Unknown device type in current IR"; + TVM_FFI_ICHECK(device_id_) << "Unknown device id in current IR"; Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {})); PrimExpr storage_scope = call->args[0]; diff --git a/src/tir/transform/lower_warp_memory.cc b/src/tir/transform/lower_warp_memory.cc index 7da7dca7a63a..073759967b10 100644 --- a/src/tir/transform/lower_warp_memory.cc +++ b/src/tir/transform/lower_warp_memory.cc @@ -119,7 +119,7 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { UpdatePattern(op->args[4]); } else if (op->op.same_as(builtin::mma_fill()) && op->args[1].as() == buffer_) { auto* local_size = op->args[0].as(); - ICHECK(local_size) << "Integer expected for the first argument of mma_fill"; + TVM_FFI_ICHECK(local_size) << "Integer expected for the first argument of mma_fill"; warp_coeff_ = local_size->value; } @@ -132,13 +132,13 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { return; } - ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to use as warp memory. " - << "Has FlattenBuffer been run?"; + TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to use as warp memory. " + << "Has FlattenBuffer been run?"; PrimExpr index = op->indices[0]; if (op->value.dtype().lanes() != 1) { arith::PVar base; - ICHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(index)) + TVM_FFI_ICHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(index)) << "LowerWarpMemory failed due to store index=" << index << ", can only handle continuous store"; UpdatePattern(base.Eval()); @@ -151,20 +151,20 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { void UpdatePattern(const PrimExpr& index) { ffi::Array m = arith::DetectLinearEquation(index, {warp_index_}); - ICHECK_EQ(m.size(), 2U) + TVM_FFI_ICHECK_EQ(m.size(), 2U) << "LowerWarpMemory failed. Could not simplify the store index `" << index << "` into the form ax + by + cz + ... Warp memory is approximated by storing values in " "thread local registers and shuffling values between these registers. Currently only " "linear equation indices are supported."; PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]); const auto* mcoeff_as_int = mcoeff.as(); - ICHECK(mcoeff_as_int && mcoeff_as_int->value > 0) + TVM_FFI_ICHECK(mcoeff_as_int && mcoeff_as_int->value > 0) << "LowerWarpMemory failed due to store index=" << index << ", require positive constant coefficient on warp index " << warp_index_ << " but get " << mcoeff; if (warp_coeff_ != 0) { - ICHECK_EQ(warp_coeff_, mcoeff_as_int->value) + TVM_FFI_ICHECK_EQ(warp_coeff_, mcoeff_as_int->value) << "LowerWarpMemory failed due to two different store coefficient to warp index"; } else { warp_coeff_ = mcoeff_as_int->value; @@ -188,7 +188,7 @@ class WarpIndexFinder : private StmtVisitor { // find the warp co-efficient and the shuffle width in the statement std::pair Find(const Stmt& stmt) { this->VisitStmt(stmt); - ICHECK(warp_index_.defined()) + TVM_FFI_ICHECK(warp_index_.defined()) << "Cannot find warp index(threadIdx.x) within the scope of warp memory"; return std::make_pair(warp_index_->var, width_); } @@ -200,14 +200,14 @@ class WarpIndexFinder : private StmtVisitor { IterVar iv = Downcast(op->node); if (iv->thread_tag == "threadIdx.x") { auto* value_as_int = op->value.as(); - ICHECK(value_as_int && value_as_int->value <= warp_size_ && - warp_size_ % value_as_int->value == 0) + TVM_FFI_ICHECK(value_as_int && value_as_int->value <= warp_size_ && + warp_size_ % value_as_int->value == 0) << "Expect threadIdx.x 's size to be no larger than, and a factor of" << " warp size(" << warp_size_ << ")" << " to enable warp memory" << " but get " << op->value << " instead"; if (warp_index_.defined()) { - ICHECK(warp_index_.same_as(iv)) + TVM_FFI_ICHECK(warp_index_.same_as(iv)) << "Find two instance of " << warp_index_->thread_tag << " in the same kernel. " << "Please create it using thread_axis once and reuse the axis " << "across multiple binds in the same kernel"; @@ -236,7 +236,7 @@ class WarpAccessRewriter : protected StmtExprMutator { Stmt Rewrite(const AllocateNode* op) { buffer_ = op->buffer_var.get(); int alloc_size = op->ConstantAllocationSize(); - ICHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size"; + TVM_FFI_ICHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size"; alloc_size *= op->dtype.lanes(); std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body); warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body); @@ -244,7 +244,7 @@ class WarpAccessRewriter : protected StmtExprMutator { // Align the local memory size. The number of elements may not // be a multiple of width_ * warp_coeff_; round it up. int factor = width_ * warp_coeff_; - ICHECK_NE(factor, 0) << "Divide by zero"; + TVM_FFI_ICHECK_NE(factor, 0) << "Divide by zero"; warp_group_ = (alloc_size + (factor - 1)) / factor; alloc_size = warp_group_ * factor; @@ -285,7 +285,7 @@ class WarpAccessRewriter : protected StmtExprMutator { } PrimExpr VisitExpr_(const VarNode* op) override { - ICHECK(op != buffer_) << "Cannot access address of warp memory directly"; + TVM_FFI_ICHECK(op != buffer_) << "Cannot access address of warp memory directly"; return StmtExprMutator::VisitExpr_(op); } @@ -293,8 +293,8 @@ class WarpAccessRewriter : protected StmtExprMutator { auto store = Downcast(StmtExprMutator::VisitStmt_(op)); if (store->buffer->data.get() == buffer_) { - ICHECK_EQ(store->indices.size(), 1) << "Expected flat memory to use as warp memory. " - << "Has FlattenBuffer been run?"; + TVM_FFI_ICHECK_EQ(store->indices.size(), 1) << "Expected flat memory to use as warp memory. " + << "Has FlattenBuffer been run?"; auto [local_index, group] = SplitIndexByGroup(store->indices[0]); (void)group; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 @@ -313,12 +313,13 @@ class WarpAccessRewriter : protected StmtExprMutator { return load; } - ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to use as warp memory. " - << "Has FlattenBuffer been run?"; + TVM_FFI_ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to use as warp memory. " + << "Has FlattenBuffer been run?"; auto [local_index, group] = SplitIndexByGroup(op->indices[0]); // invariance: local index must do not contain warp id - ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); })) + TVM_FFI_ICHECK( + !UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); })) << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->indices[0] << " local_index=" << local_index; @@ -341,7 +342,7 @@ class WarpAccessRewriter : protected StmtExprMutator { std::pair SplitIndexByGroup(const PrimExpr& index) { if (index.dtype().lanes() != 1) { arith::PVar base; - ICHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index)); + TVM_FFI_ICHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index)); auto [local_index, group] = SplitIndexByGroup(base.Eval()); local_index = Ramp(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes()); @@ -396,7 +397,7 @@ class BindVarBoundInfo : public StmtVisitor { void VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); - ICHECK_NE(iv->thread_tag.length(), 0U); + TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); if (!var_dom_.count(iv->var.get())) { Range dom = Range::FromMinExtent(0, op->value); var_dom_[iv->var.get()] = dom; @@ -452,7 +453,7 @@ Pass LowerWarpMemory() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; + TVM_FFI_ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; int warp_size = target.value()->GetAttr("thread_warp_size", 1).value().IntValue(); WarpMemoryRewriter warp_memory_rewriter(warp_size); auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body)); diff --git a/src/tir/transform/make_packed_api.cc b/src/tir/transform/make_packed_api.cc index e8a1b564a43b..4c35b3fdd891 100644 --- a/src/tir/transform/make_packed_api.cc +++ b/src/tir/transform/make_packed_api.cc @@ -58,11 +58,11 @@ class ReturnRewriter : public StmtMutator { Stmt VisitStmt_(const EvaluateNode* node) override { Stmt ret = StmtMutator::VisitStmt_(node); const EvaluateNode* eval = ret.as(); - ICHECK(eval); + TVM_FFI_ICHECK(eval); if (const CallNode* call = eval->value.as()) { if (call->op.same_as(builtin::ret())) { - ICHECK_EQ(in_parallel_, 0) << "tir.ret cannot be used in parallel scope."; - ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument."; + TVM_FFI_ICHECK_EQ(in_parallel_, 0) << "tir.ret cannot be used in parallel scope."; + TVM_FFI_ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument."; ret = WriteToOut(call->args[0]); } } @@ -94,7 +94,7 @@ class ReturnRewriter : public StmtMutator { info.type_index = ffi::TypeIndex::kTVMFFINone; info.expr = val; } else { - LOG(FATAL) << "data type " << dtype << " not supported yet"; + TVM_FFI_THROW(InternalError) << "data type " << dtype << " not supported yet"; } return info; } @@ -210,8 +210,9 @@ PrimFunc MakePackedAPI(PrimFunc func) { Target target = [&]() { auto opt = func->GetAttr(tvm::attr::kTarget); - ICHECK(opt) << "MakePackedAPI required the function to be annotated with tvm::attr::kTarget (" - << tvm::attr::kTarget << "), but the function only has attributes " << func->attrs; + TVM_FFI_ICHECK(opt) + << "MakePackedAPI required the function to be annotated with tvm::attr::kTarget (" + << tvm::attr::kTarget << "), but the function only has attributes " << func->attrs; return opt.value(); }(); int target_device_type = target->GetTargetDeviceType(); @@ -326,7 +327,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { tvm::tir::StringImm(msg.str()), nop)); arg_value = f_load_arg_value(param.dtype(), i); } else { - ICHECK(dtype.is_float()); + TVM_FFI_ICHECK(dtype.is_float()); std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be float"; seq_init.emplace_back(AssertStmt(type_index == ffi::TypeIndex::kTVMFFIFloat || @@ -398,8 +399,9 @@ PrimFunc MakePackedAPI(PrimFunc func) { func_ptr->params = args; ffi::Array undefined = UndefinedVars(func_ptr->body, func_ptr->params); - ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << name_hint << " variables " << undefined - << " are used, but are not passed in as API arguments"; + TVM_FFI_ICHECK_EQ(undefined.size(), 0) + << "In PrimFunc " << name_hint << " variables " << undefined + << " are used, but are not passed in as API arguments"; func_ptr->buffer_map = ffi::Map(); func_ptr->ret_type = PrimType(DataType::Int(32)); diff --git a/src/tir/transform/narrow_datatype.cc b/src/tir/transform/narrow_datatype.cc index 8d03f8c157ed..5ecd46d7be05 100644 --- a/src/tir/transform/narrow_datatype.cc +++ b/src/tir/transform/narrow_datatype.cc @@ -122,7 +122,7 @@ class DataTypeVisitor final : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); - ICHECK_NE(iv->thread_tag.length(), 0U); + TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value)); vextent_[iv->var.as()] = op->value.dtype(); StmtExprVisitor::VisitStmt_(op); @@ -251,8 +251,8 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter { if (is_enabled_ && visitor_.vmap.find(op) != visitor_.vmap.end()) { PrimExpr e = Parent::VisitExpr_(op); const CastNode* new_op = e.as(); - ICHECK(new_op != nullptr) << "Expected type to be CastNode" - << ", but get " << e->GetTypeKey(); + TVM_FFI_ICHECK(new_op != nullptr) << "Expected type to be CastNode" + << ", but get " << e->GetTypeKey(); PrimExpr new_value = new_op->value; DataType cast_type = visitor_.vmap[op]; if (new_value.dtype() != cast_type) { diff --git a/src/tir/transform/remap_thread_axis.cc b/src/tir/transform/remap_thread_axis.cc index c7184e07a036..4a47e43d06d1 100644 --- a/src/tir/transform/remap_thread_axis.cc +++ b/src/tir/transform/remap_thread_axis.cc @@ -42,7 +42,7 @@ class ThreadAxisRewriter : private StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); - ICHECK_NE(iv->thread_tag.length(), 0U); + TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); auto it = tmap_.find(iv->thread_tag); if (it != tmap_.end()) { const IterVar& new_iv = it->second; @@ -50,7 +50,7 @@ class ThreadAxisRewriter : private StmtExprMutator { if (!vmap_.count(v)) { vmap_[v] = new_iv->var; } else { - ICHECK(vmap_[v].same_as(new_iv->var)); + TVM_FFI_ICHECK(vmap_[v].same_as(new_iv->var)); } Stmt body = this->VisitStmt(op->body); return AttrStmt(new_iv, op->attr_key, op->value, body); @@ -77,7 +77,7 @@ PrimFunc RemapThreadAxis(PrimFunc func, ffi::Map thread_ma } if (auto opt = func->GetAttr>(tir::attr::kKernelLaunchParams)) { - ICHECK(opt != nullptr) << "Require attribute " << tir::attr::kKernelLaunchParams; + TVM_FFI_ICHECK(opt != nullptr) << "Require attribute " << tir::attr::kKernelLaunchParams; auto launch_params = opt.value(); // replace the thread axis attribute for (size_t i = 0; i < launch_params.size(); ++i) { diff --git a/src/tir/transform/remove_no_op.cc b/src/tir/transform/remove_no_op.cc index 6cc80535085f..010d189d8930 100644 --- a/src/tir/transform/remove_no_op.cc +++ b/src/tir/transform/remove_no_op.cc @@ -119,7 +119,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { // For example, a wait count 1 - i can be negative after loop unrolling. // We assume that such wait is a nop. auto inner = op->body.as(); - ICHECK(inner); + TVM_FFI_ICHECK(inner); return Parent::VisitStmt(inner->body); } } diff --git a/src/tir/transform/simplify.cc b/src/tir/transform/simplify.cc index 648710584a20..f06e52f32864 100644 --- a/src/tir/transform/simplify.cc +++ b/src/tir/transform/simplify.cc @@ -322,7 +322,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { ffi::Optional ProveCondition(PrimExpr condition) const { condition = Substitute(condition, non_inlined_bindings_); if (config_->propagate_knowns_to_prove_conditional) { - ICHECK(touch_pattern_.has_value()); + TVM_FFI_ICHECK(touch_pattern_.has_value()); condition = touch_pattern_->SimplifyInContext(condition, current_stmt_.value(), analyzer_); } else { condition = analyzer_->Simplify(condition); diff --git a/src/tir/transform/storage_rewrite.cc b/src/tir/transform/storage_rewrite.cc index c6ce86955858..e7acd2853381 100644 --- a/src/tir/transform/storage_rewrite.cc +++ b/src/tir/transform/storage_rewrite.cc @@ -112,10 +112,10 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { const VarNode* buffer_var = op->buffer->data.get(); auto it = alloc_info_.find(buffer_var); if (it != alloc_info_.end() && it->second.alloc) { - ICHECK_LT(it->second.level, scope_.size()); + TVM_FFI_ICHECK_LT(it->second.level, scope_.size()); scope_[it->second.level].touched.push_back(buffer_var); - ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions) + TVM_FFI_ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions) << "Buffer " << op->buffer->name << " is allocated with " << it->second.num_physical_dimensions << " physical dimensions, but is accessed as having " @@ -138,10 +138,11 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { const VarNode* buffer_var = op->buffer->data.get(); auto it = alloc_info_.find(buffer_var); if (it != alloc_info_.end() && it->second.alloc) { - ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store."; + TVM_FFI_ICHECK_LT(it->second.level, scope_.size()) + << "Load memory in places other than store."; scope_[it->second.level].touched.push_back(buffer_var); - ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions) + TVM_FFI_ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions) << "Buffer " << op->buffer->name << " is allocated with " << it->second.num_physical_dimensions << " physical dimensions, but is accessed as having " @@ -165,7 +166,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { // Directly reference to the variable count as a read. auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { - ICHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint; + TVM_FFI_ICHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint; scope_[it->second.level].touched.push_back(buf); } } @@ -183,11 +184,11 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { e.touched = std::move(scope_.back().touched); scope_.pop_back(); int64_t end_index = static_cast(linear_seq_.size()); - ICHECK_GT(end_index, begin_index); + TVM_FFI_ICHECK_GT(end_index, begin_index); e.scope_pair_offset = begin_index - end_index; linear_seq_.push_back(e); // record the pointer to end index. - ICHECK_NE(end_index, 0U); + TVM_FFI_ICHECK_NE(end_index, 0U); linear_seq_[begin_index].scope_pair_offset = end_index - begin_index; } @@ -339,7 +340,7 @@ class InplaceOpVerifier : public StmtExprVisitor { result_ = false; return; } - ICHECK_EQ(store_->indices.size(), op->indices.size()) + TVM_FFI_ICHECK_EQ(store_->indices.size(), op->indices.size()) << "Store/Load occur to the same buffer " << buf->name_hint << " with differing number of indices"; for (size_t i = 0; i < store_->indices.size(); i++) { @@ -421,7 +422,7 @@ class StoragePlanRewriter : public StmtExprMutator { auto key = buf.get(); auto it = buffer_remap_.find(key); if (it != buffer_remap_.end()) { - ICHECK_EQ(it->second->data.get(), new_backing_array.get()) + TVM_FFI_ICHECK_EQ(it->second->data.get(), new_backing_array.get()) << "Cannot remap buffer " << buf->name << " to use backing array " << new_backing_array->name_hint << ", previously used backing array " << it->second->data->name_hint; @@ -458,7 +459,7 @@ class StoragePlanRewriter : public StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_access_ptr())) { - ICHECK_EQ(op->args.size(), 5U); + TVM_FFI_ICHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); auto it = alloc_map_.find(buffer); @@ -469,7 +470,7 @@ class StoragePlanRewriter : public StmtExprMutator { PrimExpr offset = this->VisitExpr(op->args[2]); PrimExpr extent = this->VisitExpr(op->args[3]); uint64_t elem_bits = dtype.bits() * dtype.lanes(); - ICHECK_EQ(se->bits_offset % elem_bits, 0U); + TVM_FFI_ICHECK_EQ(se->bits_offset % elem_bits, 0U); if (se->bits_offset != 0) { offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset; } @@ -503,7 +504,7 @@ class StoragePlanRewriter : public StmtExprMutator { } Stmt VisitStmt_(const ForNode* op) final { - ICHECK(op->kind != ForKind::kVectorized) << "VectorizeLoop before LiftStorageAlloc"; + TVM_FFI_ICHECK(op->kind != ForKind::kVectorized) << "VectorizeLoop before LiftStorageAlloc"; // remake all the allocation at the attach scope. if (attach_map_.count(op)) { auto& svec = attach_map_[op]; @@ -597,7 +598,7 @@ class StoragePlanRewriter : public StmtExprMutator { PrimExpr RemapIndex(DataType dtype, PrimExpr index, StorageEntry* e) { if (e->bits_offset == 0) return index; uint64_t elem_bits = dtype.bits(); - ICHECK_EQ(e->bits_offset % elem_bits, 0U); + TVM_FFI_ICHECK_EQ(e->bits_offset % elem_bits, 0U); return make_const(index.dtype(), e->bits_offset / elem_bits) + index; } // Prepare the new allocations @@ -614,7 +615,7 @@ class StoragePlanRewriter : public StmtExprMutator { for (size_t i = 0; i < vec.size(); ++i) { StorageEntry* e = vec[i]; if (IsSpecialTaggedMemory(e->scope)) { - ICHECK_NE(e->const_nbits, 0U) << "Special tagged memory must be const size"; + TVM_FFI_ICHECK_NE(e->const_nbits, 0U) << "Special tagged memory must be const size"; for (size_t j = 0; j < i; ++j) { if (e->scope == vec[j]->scope) { vec[j]->merged_children.push_back(e); @@ -672,7 +673,7 @@ class StoragePlanRewriter : public StmtExprMutator { // Build a merged allocation PrimExpr combo_size; for (const AllocateNode* op : e->allocs) { - ICHECK_EQ(op->extents.size(), 1) + TVM_FFI_ICHECK_EQ(op->extents.size(), 1) << "Buffer var " << op->buffer_var->name_hint << " was identified as a re-usable allocation, but has " << op->extents.size() << " physical dimensions. " @@ -714,9 +715,9 @@ class StoragePlanRewriter : public StmtExprMutator { } // New allocation for merged data void NewAllocTagMerged(StorageEntry* e) { - ICHECK_NE(e->scope.tag.length(), 0U); + TVM_FFI_ICHECK_NE(e->scope.tag.length(), 0U); // allocate with element type. - ICHECK_NE(e->const_nbits, 0U); + TVM_FFI_ICHECK_NE(e->const_nbits, 0U); uint64_t total_bits = e->const_nbits; // By default, align to 32 bits. size_t align = 32; @@ -727,8 +728,8 @@ class StoragePlanRewriter : public StmtExprMutator { } e->alloc_var = e->allocs[0]->buffer_var; for (StorageEntry* child : e->merged_children) { - ICHECK_NE(child->const_nbits, 0U); - ICHECK_NE(total_bits, 0U); + TVM_FFI_ICHECK_NE(child->const_nbits, 0U); + TVM_FFI_ICHECK_NE(total_bits, 0U); child->bits_offset = total_bits; child->alloc_var = e->alloc_var; total_bits += child->const_nbits; @@ -771,7 +772,7 @@ class StoragePlanRewriter : public StmtExprMutator { } void PlanNewScope(const Object* op) { if (thread_scope_ != nullptr) { - ICHECK(thread_scope_ == op); + TVM_FFI_ICHECK(thread_scope_ == op); // erase all memory atatched to this scope. for (auto it = const_free_map_.begin(); it != const_free_map_.end();) { if (it->second->attach_scope_ == op) { @@ -813,7 +814,7 @@ class StoragePlanRewriter : public StmtExprMutator { bool detect_inplace = detect_inplace_ && (it->second.gen.size() <= 2); for (const VarNode* var : it->second.gen) { - ICHECK(alloc_info.count(var)); + TVM_FFI_ICHECK(alloc_info.count(var)); const AllocEntry& entry = alloc_info.at(var); const AllocateNode* alloc = entry.alloc; auto storage_scope = StorageScope::Create(GetPtrStorageScope(ffi::GetRef(var))); @@ -858,7 +859,7 @@ class StoragePlanRewriter : public StmtExprMutator { attr::IsPragmaKey(op->attr_key)) { PlanNewScope(op); } else { - ICHECK(op->attr_key == attr::extern_scope); + TVM_FFI_ICHECK(op->attr_key == attr::extern_scope); } } else if (s.stmt->IsInstance()) { const auto* op = static_cast(s.stmt); @@ -885,7 +886,7 @@ class StoragePlanRewriter : public StmtExprMutator { // Allocate new storage entry. StorageEntry* NewAlloc(const AllocateNode* op, const Object* attach_scope, const StorageScope& scope, size_t const_nbits) { - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); // Re-use not successful, allocate a new buffer. auto entry = std::make_unique(); entry->attach_scope_ = attach_scope; @@ -900,7 +901,7 @@ class StoragePlanRewriter : public StmtExprMutator { StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope, const StorageScope& scope, size_t num_physical_dimensions, bool enable_reuse, bool reuse_require_exact_matched_dtype) { - ICHECK(op != nullptr); + TVM_FFI_ICHECK(op != nullptr); // skip plan for local variable, // compiler can do a better job with register allocation. const uint64_t match_range = 16; @@ -975,9 +976,9 @@ class StoragePlanRewriter : public StmtExprMutator { // simulated free. void Free(const VarNode* var) { auto it = alloc_map_.find(var); - ICHECK(it != alloc_map_.end()); + TVM_FFI_ICHECK(it != alloc_map_.end()); StorageEntry* e = it->second; - ICHECK_NE(e->allocs.size(), 0U); + TVM_FFI_ICHECK_NE(e->allocs.size(), 0U); // disable reuse of small arrays, they will be lowered to registers in LLVM // This rules only apply if we are using non special memory @@ -1197,9 +1198,9 @@ class VectorTypeAccessChecker : public StmtExprVisitor { } else if (allow_untyped_pointers_) { OnArrayDeclaration(let_var, let_var->dtype, 0, BufferVarInfo::kLetNode); } else { - LOG(FATAL) << "Let statement of variable " << let_var->name_hint - << " is missing a type annotation, " - << "or type annotation is not a pointer to primitive"; + TVM_FFI_THROW(InternalError) << "Let statement of variable " << let_var->name_hint + << " is missing a type annotation, " + << "or type annotation is not a pointer to primitive"; } } } @@ -1219,7 +1220,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { */ void OnArrayDeclaration(Var buffer, DataType element_dtype, PrimExpr extent, BufferVarInfo::DeclarationLocation declaration_location) { - ICHECK(info_map_.find(buffer.get()) == info_map_.end()) + TVM_FFI_ICHECK(info_map_.find(buffer.get()) == info_map_.end()) << "Array declaration of " << buffer->name_hint << " occurred multiple times."; if (element_dtype == DataType::Bool()) { @@ -1242,8 +1243,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor { void OnArrayAccess(DataType value_dtype, const VarNode* buffer, const ffi::Array& indices, bool is_buffer_load) { auto it = info_map_.find(buffer); - ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer - << ") occurred before its declaration."; + TVM_FFI_ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" + << buffer << ") occurred before its declaration."; if (value_dtype.is_scalable_vector()) { // Scalable types are not currently supported in storage_rewrite. Scalable buffer @@ -1258,13 +1259,14 @@ class VectorTypeAccessChecker : public StmtExprVisitor { } if (var_info.element_dtype.is_handle()) { - ICHECK(allow_untyped_pointers_) << "Variable " << buffer->name_hint - << " was missing a type annotation in its declaration"; + TVM_FFI_ICHECK(allow_untyped_pointers_) + << "Variable " << buffer->name_hint + << " was missing a type annotation in its declaration"; var_info.element_dtype = value_dtype.element_of(); } for (int i = 0; i < static_cast(indices.size()) - 1; i++) { - ICHECK(indices[i].dtype().is_scalar()) + TVM_FFI_ICHECK(indices[i].dtype().is_scalar()) << "Only the last index of a buffer access may be a vector type."; } int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1; @@ -1280,7 +1282,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { // pointer types (e.g. float16x4*). Once they do, this if statement should // instead be replaced by the below ICHECK_EQ. if (index_lanes * var_info.element_dtype.lanes() != value_dtype.lanes()) { - ICHECK_EQ(index_lanes, value_dtype.lanes()); + TVM_FFI_ICHECK_EQ(index_lanes, value_dtype.lanes()); lanes_used = 1; var_info.element_dtype = var_info.element_dtype.with_lanes(1); } @@ -1289,7 +1291,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 // for discussion. - // ICHECK_EQ(index_lanes * var_info.element_dtype.lanes(), value_dtype.lanes()) + // TVM_FFI_ICHECK_EQ(index_lanes * var_info.element_dtype.lanes(), value_dtype.lanes()) // << "Attempting to retrieve " << value_dtype.lanes() << " lanes of data with " // << index_lanes << " indices into an array whose elements have " // << var_info.element_dtype.lanes() << " lanes. " @@ -1449,14 +1451,14 @@ class VectorTypeRewriter : public StmtExprMutator { int lanes = static_cast(Downcast(ramp_index->lanes)->value); PrimExpr new_index = ramp_index->base / make_const(ramp_index->base.dtype(), lanes); if (lanes != info.factor()) { - ICHECK(info.factor() && lanes % info.factor() == 0); + TVM_FFI_ICHECK(info.factor() && lanes % info.factor() == 0); int new_lanes = lanes / info.factor(); new_index = Ramp(new_index * new_lanes, ramp_index->stride, new_lanes, ramp_index->span); } indices.Set(indices.size() - 1, new_index); } else if (last_dim_index.dtype().lanes() == 1 && info.factor() > 1) { arith::ModularSet me = analyzer_.modular_set(last_dim_index); - ICHECK(me->coeff == 0 || info.factor() % me->coeff == 0); + TVM_FFI_ICHECK(me->coeff == 0 || info.factor() % me->coeff == 0); PrimExpr new_index = last_dim_index / make_const(last_dim_index.dtype(), info.factor()); shuffle_index = me->base % info.factor(); indices.Set(indices.size() - 1, new_index); @@ -1490,7 +1492,7 @@ class VectorTypeRewriter : public StmtExprMutator { Stmt VisitStmt_(const BufferStoreNode* op) final { auto node = Downcast(StmtExprMutator::VisitStmt_(op)); auto [modified, shuffle_index] = VisitBufferAccess(std::move(node)); - ICHECK(shuffle_index < 0); + TVM_FFI_ICHECK(shuffle_index < 0); return modified; } @@ -1590,7 +1592,7 @@ class VectorTypeRewriter : public StmtExprMutator { * @param func A pointer to the PrimFunc being modified. */ void Finalize(PrimFunc* func_ptr) { - ICHECK(func_ptr) << "Finalize expects a non-null pointer"; + TVM_FFI_ICHECK(func_ptr) << "Finalize expects a non-null pointer"; auto& func = *func_ptr; auto* n = func.CopyOnWrite(); @@ -1638,7 +1640,7 @@ class VectorTypeRewriter : public StmtExprMutator { int factor() const { int old_lanes = old_element_dtype.lanes(); int new_lanes = new_element_dtype.lanes(); - ICHECK_EQ(new_lanes % old_lanes, 0); + TVM_FFI_ICHECK_EQ(new_lanes % old_lanes, 0); return new_lanes / old_lanes; } }; diff --git a/src/tir/transform/unroll_loop.cc b/src/tir/transform/unroll_loop.cc index 7b92bad12d34..87a3e2be363c 100644 --- a/src/tir/transform/unroll_loop.cc +++ b/src/tir/transform/unroll_loop.cc @@ -130,7 +130,7 @@ class LoopUnroller : public StmtExprMutator { auto_unroll && (value * step_count_ <= auto_max_step_ || value <= auto_max_extent_); if (op->kind == ForKind::kUnrolled) { - ICHECK_GE(value, 0) << "Cannot unroll non-constant loop"; + TVM_FFI_ICHECK_GE(value, 0) << "Cannot unroll non-constant loop"; auto_unroll = true; } @@ -217,7 +217,7 @@ class LoopUnroller : public StmtExprMutator { Stmt Unroll(const ForNode* op) { int value = GetExtent(op); // For loop must have a constant integer extent - ICHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; + TVM_FFI_ICHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; if (value == 0) return Evaluate(0); Stmt body = op->body; ffi::Map vmap; diff --git a/src/tir/transform/unsupported_dtype_legalize.cc b/src/tir/transform/unsupported_dtype_legalize.cc index 0ae17b54846e..6ef5c0520534 100644 --- a/src/tir/transform/unsupported_dtype_legalize.cc +++ b/src/tir/transform/unsupported_dtype_legalize.cc @@ -68,7 +68,7 @@ class ComputeLegalizePlanner : public StmtExprVisitor { } for (Buffer buffer : drop_buffers) { auto it = buffer_remap_->find(buffer); - ICHECK(it != buffer_remap_->end()); + TVM_FFI_ICHECK(it != buffer_remap_->end()); buffer_remap_->erase(it); } } @@ -328,11 +328,12 @@ class ComputeLegalizer : public StmtExprMutator { if (value.dtype() != new_buf->dtype) { // this happens when buffer get rewritten to f32 // but values remain as fp8/bf16 - ICHECK(MatchDType(value->dtype)); + TVM_FFI_ICHECK(MatchDType(value->dtype)); value = DTypeConversion(value, new_buf->dtype.with_lanes(value.dtype().lanes())); } - ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " - "data type legalizer pass."; + TVM_FFI_ICHECK(!op->predicate.defined()) + << "Predicated buffer store is not currently supported in " + "data type legalizer pass."; return BufferStore(new_buf, value, indices); } } @@ -411,9 +412,9 @@ class ComputeLegalizer : public StmtExprMutator { if (it != var_remap_.end()) { Var remapped_var = it->second; auto* ptr = remapped_var->type_annotation.as(); - ICHECK(ptr); + TVM_FFI_ICHECK(ptr); auto* prim_type = ptr->element_type.as(); - ICHECK(prim_type); + TVM_FFI_ICHECK(prim_type); return Allocate(remapped_var, prim_type->dtype, op->extents, op->condition, op->body); } else { return ret; @@ -428,8 +429,9 @@ class ComputeLegalizer : public StmtExprMutator { if (new_buf.same_as(op->buffer)) { return ret; } else { - ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in " - "data type legalizer pass."; + TVM_FFI_ICHECK(!op->predicate.defined()) + << "Predicated buffer load is not currently supported in " + "data type legalizer pass."; return BufferLoad(new_buf, op->indices); } } @@ -457,7 +459,7 @@ class ComputeLegalizer : public StmtExprMutator { */ PrimExpr CastTargetToDType(PrimExpr value, DataType dtype) { if (!value.dtype().is_float()) return value; - ICHECK_EQ(value.dtype(), this->promote_dtype_.with_lanes(value.dtype().lanes())); + TVM_FFI_ICHECK_EQ(value.dtype(), this->promote_dtype_.with_lanes(value.dtype().lanes())); return DTypeConversion(value, dtype); } @@ -505,7 +507,7 @@ class FP8ComputeLegalizer : public ComputeLegalizer { class StorageLegalizer : public StmtExprMutator { public: PrimFunc Legalize(PrimFunc func) { - ICHECK_EQ(func->buffer_map.size(), 0) << "This pass must be called after MakePackedAPI"; + TVM_FFI_ICHECK_EQ(func->buffer_map.size(), 0) << "This pass must be called after MakePackedAPI"; auto* n = func.CopyOnWrite(); n->params = n->params.Map([this](Var var) { return this->RemapVarDef(var); }); n->body = this->VisitStmt(std::move(n->body)); @@ -589,10 +591,11 @@ class StorageLegalizer : public StmtExprMutator { return ffi::GetRef(op); } else { if (MatchDType(op->value.dtype())) { - ICHECK(new_buf->dtype.is_uint()); + TVM_FFI_ICHECK(new_buf->dtype.is_uint()); } - ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " - "data type legalizer pass."; + TVM_FFI_ICHECK(!op->predicate.defined()) + << "Predicated buffer store is not currently supported in " + "data type legalizer pass."; return BufferStore(new_buf, value, indices); } } @@ -622,8 +625,9 @@ class StorageLegalizer : public StmtExprMutator { if (new_buf.same_as(op->buffer)) { return ret; } else { - ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in " - "data type legalizer pass."; + TVM_FFI_ICHECK(!op->predicate.defined()) + << "Predicated buffer load is not currently supported in " + "data type legalizer pass."; return BufferLoad(new_buf, op->indices); } } @@ -695,7 +699,7 @@ class StorageLegalizer : public StmtExprMutator { buf->data_alignment, buf->offset_factor, buf->buffer_type, buf->axis_separators, buf->span); } else { - ICHECK(!MatchDType(buf->dtype)) << "Cannot find var remap for " << buf; + TVM_FFI_ICHECK(!MatchDType(buf->dtype)) << "Cannot find var remap for " << buf; } buffer_remap_[buf] = new_buf; diff --git a/src/tir/transform/update_pointer_storage_scope.cc b/src/tir/transform/update_pointer_storage_scope.cc index e12ab9696a99..39360b559fd1 100644 --- a/src/tir/transform/update_pointer_storage_scope.cc +++ b/src/tir/transform/update_pointer_storage_scope.cc @@ -39,7 +39,7 @@ namespace tir { Var WithStorageScope(const VarNode* buffer_var, ffi::String storage_scope) { auto* ptr_type = buffer_var->type_annotation.as(); - ICHECK(ptr_type) << "The provided variable is not of pointer type"; + TVM_FFI_ICHECK(ptr_type) << "The provided variable is not of pointer type"; return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), buffer_var->span); } diff --git a/src/tir/transform/vectorize_loop.cc b/src/tir/transform/vectorize_loop.cc index 331f556a9442..2e8f1811996a 100644 --- a/src/tir/transform/vectorize_loop.cc +++ b/src/tir/transform/vectorize_loop.cc @@ -58,7 +58,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { return e; if (const BroadcastNode* op = e.as()) { - ICHECK(op->dtype.is_scalable_vector() == is_scalable) + TVM_FFI_ICHECK(op->dtype.is_scalable_vector() == is_scalable) << "Can't broadcast between scalable and fixed length vectors."; int e_lanes = op->dtype.get_lanes_or_vscale_factor(); @@ -67,10 +67,9 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { } } - ICHECK(e.dtype().is_scalar()) << "Cannot broadcast lanes=" - << e.dtype().get_lanes_or_vscale_factor() - << " is_scalable=" << e.dtype().is_scalable_vector() << " to " - << lanes; + TVM_FFI_ICHECK(e.dtype().is_scalar()) + << "Cannot broadcast lanes=" << e.dtype().get_lanes_or_vscale_factor() + << " is_scalable=" << e.dtype().is_scalable_vector() << " to " << lanes; return Broadcast(e, CreateNewLanes(is_scalable, lanes)); } @@ -300,7 +299,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->base); PrimExpr stride = this->VisitExpr(op->stride); - ICHECK(!base.dtype().is_scalable_vector()) + TVM_FFI_ICHECK(!base.dtype().is_scalable_vector()) << "Creating scalable vectors from existing vectors is not supported."; - ICHECK(!stride.dtype().is_scalable_vector()) + TVM_FFI_ICHECK(!stride.dtype().is_scalable_vector()) << "Ramp stride with scalable dtype is not supported"; if (base.dtype().is_fixed_length_vector() && stride.dtype().is_scalar()) { - ICHECK(op->lanes->IsInstance()) + TVM_FFI_ICHECK(op->lanes->IsInstance()) << "Vectorizing over existing scalable vectors is not supported."; const RampNode* base_ramp = base.as(); int op_lanes = static_cast(Downcast(op->lanes)->value); @@ -497,7 +496,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorop.same_as(builtin::reinterpret())); + TVM_FFI_ICHECK(op->op.same_as(builtin::reinterpret())); PrimExpr value = this->VisitExpr(op->args[0]); if (value.same_as(op->args[0])) { return ffi::GetRef(op); @@ -526,7 +525,7 @@ class Vectorizer : public StmtMutator, public ExprFunctortype_annotation.as() ->element_type.as() ->dtype; - ICHECK(lane * dtype.bits() <= op->args[4].as()->value) + TVM_FFI_ICHECK(lane * dtype.bits() <= op->args[4].as()->value) << "Expected Data to be Read is lesser than or equal to Texture Load length"; auto new_args = op->args; @@ -543,7 +542,7 @@ class Vectorizer : public StmtMutator, public ExprFunctortype_annotation.as() ->element_type.as() ->dtype; - ICHECK(lane * dtype.bits() == op->args[4].as()->value) + TVM_FFI_ICHECK(lane * dtype.bits() == op->args[4].as()->value) << "Expected Data to be Written equal to Texture Store length"; ffi::Array new_args{op->args[0], op->args[1], op->args[2], op->args[3], op->args[4], mutated_value[0]}; @@ -625,7 +624,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorvar); if (it != let_binding_.end()) { - ICHECK(deep_equal_(it->second, value)) + TVM_FFI_ICHECK(deep_equal_(it->second, value)) << "Let cannot bind the same var to two different values"; } if (value.dtype().get_lanes_or_vscale_factor() != @@ -644,7 +643,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorvectors.size() == 1 && op->indices.size() == 1) + TVM_FFI_ICHECK(op->vectors.size() == 1 && op->indices.size() == 1) << "Cannot vectorize ShuffleNode with multiple vectors or indices: the vector size is " << op->vectors.size() << " and the index size is " << op->indices.size(); int lane_vectors = 0; @@ -689,7 +688,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorvectors[0], {{var_, tvm::IntImm(var_->dtype, 0)}}); @@ -715,7 +714,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorVisitExpr(op->value); if (!indices.same_as(op->indices) || !value.same_as(op->value)) { - ICHECK(!op->buffer->dtype.is_scalable_vector()) + TVM_FFI_ICHECK(!op->buffer->dtype.is_scalable_vector()) << "Vectorizing over scalable buffer elements is not supported in vectorizer."; // How many lanes of indexing are present in the index and // buffer element type, excluding the last index. @@ -723,7 +722,8 @@ class Vectorizer : public StmtMutator, public ExprFunctorbuffer->name << ", cannot produce " << total_lanes << " lanes of storage location by changing the last index."; int last_index_lanes = total_lanes / other_index_lanes; @@ -760,8 +760,8 @@ class Vectorizer : public StmtMutator, public ExprFunctorkind == ForKind::kVectorized) { LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring..."; } - ICHECK(is_zero(op->min)); - ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector()); + TVM_FFI_ICHECK(is_zero(op->min)); + TVM_FFI_ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector()); PrimExpr extent = this->VisitExpr(op->extent); if (extent.dtype().is_scalable_or_fixed_length_vector()) { return Scalarize(ffi::GetRef(op)); @@ -778,7 +778,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorcondition.dtype().is_scalable_or_fixed_length_vector()); + TVM_FFI_ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector()); PrimExpr condition = this->VisitExpr(op->condition); // need scalarize can be marked as true during visit of condition bool cond_need_scalarize = false; @@ -813,7 +813,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op)); } - ICHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice"; + TVM_FFI_ICHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice"; let_binding_[op->var] = value; if (value.dtype().get_lanes_or_vscale_factor() != @@ -985,10 +985,10 @@ class LoopVectorizer : public StmtMutator { if (!extent_as_int || extent_as_int->value < 1) { bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall); - ICHECK(is_scalable_expr && arith::TargetHasVLA(target_)) + TVM_FFI_ICHECK(is_scalable_expr && arith::TargetHasVLA(target_)) << "Failed to vectorize loop with extent " << op->extent << " for target " << target_; } - ICHECK(is_zero(op->min)); + TVM_FFI_ICHECK(is_zero(op->min)); return Vectorizer(op->loop_var, op->extent, target_)(op->body); } else { return StmtMutator::VisitStmt_(op); diff --git a/src/topi/einsum.cc b/src/topi/einsum.cc index 1b02a80f6387..53f604defd96 100644 --- a/src/topi/einsum.cc +++ b/src/topi/einsum.cc @@ -44,8 +44,8 @@ EinsumEquation EinsumEquation::FromString(const std::string& equation) { break; case '-': // Arrow - CHECK(!has_arrow) << "Equation can only have one arrow"; - CHECK(i + 1 < n && equation[i + 1] == '>') + TVM_FFI_ICHECK(!has_arrow) << "Equation can only have one arrow"; + TVM_FFI_ICHECK(i + 1 < n && equation[i + 1] == '>') << "Cannot parse the Einsum equation: invalid arrow"; i++; has_arrow = true; @@ -58,8 +58,8 @@ EinsumEquation EinsumEquation::FromString(const std::string& equation) { break; case '.': // Ellipsis - CHECK(!has_ellipsis) << "Ellipsis can only appear once for each input and output"; - CHECK(i + 2 < n && equation[i + 1] == '.' && equation[i + 2] == '.') + TVM_FFI_ICHECK(!has_ellipsis) << "Ellipsis can only appear once for each input and output"; + TVM_FFI_ICHECK(i + 2 < n && equation[i + 1] == '.' && equation[i + 2] == '.') << "Cannot parse the Einsum equation: invalid ellipsis"; current.push_back(kEllipsis); has_ellipsis = true; @@ -67,8 +67,9 @@ EinsumEquation EinsumEquation::FromString(const std::string& equation) { break; default: // Default case: current character is a subscript label - CHECK(std::isalpha(equation[i])) << "Cannot parse the Einsum equation: invalid character " - << equation[i] << " in equation " << equation; + TVM_FFI_ICHECK(std::isalpha(equation[i])) + << "Cannot parse the Einsum equation: invalid character " << equation[i] + << " in equation " << equation; current.emplace_back(equation[i]); break; } @@ -110,7 +111,7 @@ PrimExpr GetBroadcastedExtent(const PrimExpr& extent1, const PrimExpr& extent2) } else if (extent1_imm->value == 1 || extent2_imm->value == 1) { return Integer(std::max(extent1_imm->value, extent2_imm->value)); } - LOG(FATAL) << "Cannot broadcast extents " << extent1 << " and " << extent2; + TVM_FFI_THROW(InternalError) << "Cannot broadcast extents " << extent1 << " and " << extent2; throw; } else if (extent1_imm != nullptr) { return extent2; @@ -147,7 +148,7 @@ class EinsumBuilder { * \return The inferred shape of the output */ ffi::Array InferShape() { - CHECK_EQ(equation_.inputs.size(), input_shapes_.size()) + TVM_FFI_ICHECK_EQ(equation_.inputs.size(), input_shapes_.size()) << "Number of operands does not match the " "equation"; @@ -179,7 +180,7 @@ class EinsumBuilder { } } } - ICHECK_EQ(current_dim, input_shape.size()); + TVM_FFI_ICHECK_EQ(current_dim, input_shape.size()); } // Step 2: Infer the shape of the ellipsis if exists @@ -253,7 +254,7 @@ class EinsumBuilder { label_to_index->emplace(label, indices[i++]); } } - ICHECK_EQ(i, indices.size()); + TVM_FFI_ICHECK_EQ(i, indices.size()); } /*! @@ -321,8 +322,8 @@ class EinsumBuilder { label_to_extent_.at(label))); } } - ICHECK_EQ(i, input_shape.size()); - ICHECK_EQ(indices.size(), input_shape.size()); + TVM_FFI_ICHECK_EQ(i, input_shape.size()); + TVM_FFI_ICHECK_EQ(indices.size(), input_shape.size()); return indices; } diff --git a/src/topi/transform.cc b/src/topi/transform.cc index d9545e637405..5e2ffd4cbd9a 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -116,7 +116,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { int batch_dims = args[2].cast(); *rv = take(args[0].cast(), args[1].cast(), batch_dims, mode); } else { - ICHECK_EQ(args.size(), 5) << "topi.take expects 4 or 5 arguments"; + TVM_FFI_ICHECK_EQ(args.size(), 5) << "topi.take expects 4 or 5 arguments"; int batch_dims = args[2].cast(); int axis = args[3].cast(); auto mode = args[4].cast(); @@ -192,7 +192,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { args[2].cast(), args[3].cast()); break; default: - ICHECK(0) << "topi.matmul expects 2, 3 or 4 arguments"; + TVM_FFI_ICHECK(0) << "topi.matmul expects 2, 3 or 4 arguments"; } }) .def_packed("topi.tensordot", diff --git a/tests/cpp-runtime/hexagon/hexagon_buffer_tests.cc b/tests/cpp-runtime/hexagon/hexagon_buffer_tests.cc index febf484f8161..d7fb9cbb36f0 100644 --- a/tests/cpp-runtime/hexagon/hexagon_buffer_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_buffer_tests.cc @@ -47,7 +47,7 @@ TEST(HexagonBuffer, vtcm_scope) { TEST(HexagonBuffer, invalid_scope) { ffi::Optional scope(ffi::String("invalid")); - EXPECT_THROW(HexagonBuffer hb(8 /* nbytes */, 8 /* alignment */, scope), InternalError); + EXPECT_THROW(HexagonBuffer hb(8 /* nbytes */, 8 /* alignment */, scope), tvm::ffi::Error); } TEST(HexagonBuffer, micro_copies_corresponding_regions) { @@ -199,13 +199,13 @@ TEST(HexagonBuffer, micro_copies_invalid_size) { { BufferSet src(src_ptr.data(), 1, 16); BufferSet dest(dest_ptr.data(), 2, 16); - EXPECT_THROW(BufferSet::MemoryCopies(dest, src, 24), InternalError); + EXPECT_THROW(BufferSet::MemoryCopies(dest, src, 24), tvm::ffi::Error); } { BufferSet src(src_ptr.data(), 2, 16); BufferSet dest(dest_ptr.data(), 1, 16); - EXPECT_THROW(BufferSet::MemoryCopies(dest, src, 24), InternalError); + EXPECT_THROW(BufferSet::MemoryCopies(dest, src, 24), tvm::ffi::Error); } } @@ -287,7 +287,7 @@ TEST(HexagonBuffer, copy_from_invalid_size) { // HexagonBuffer too small HexagonBuffer toosmall(4 /* nbytes */, 8 /* alignment */, scope); - EXPECT_THROW(toosmall.CopyFrom(data.data(), data.size()), InternalError); + EXPECT_THROW(toosmall.CopyFrom(data.data(), data.size()), tvm::ffi::Error); } TEST(HexagonBuffer, copy_from_smaller_size) { @@ -397,12 +397,12 @@ TEST(HexagonBuffer, nd_copy_from_nd_invalid_size) { HexagonBuffer hb2d(2 /* ndim */, 4 /* nbytes */, 8 /* alignment */, scope); HexagonBuffer toosbig1d(16 /* nbytes */, 16 /* alignment */, scope); - EXPECT_THROW(hb1d.CopyFrom(toosbig1d, 16), InternalError); - EXPECT_THROW(hb2d.CopyFrom(toosbig1d, 16), InternalError); + EXPECT_THROW(hb1d.CopyFrom(toosbig1d, 16), tvm::ffi::Error); + EXPECT_THROW(hb2d.CopyFrom(toosbig1d, 16), tvm::ffi::Error); HexagonBuffer toobig2d(2 /* ndim */, 16 /* nbytes */, 16 /* alignment */, scope); - EXPECT_THROW(hb1d.CopyFrom(toobig2d, 32), InternalError); - EXPECT_THROW(hb2d.CopyFrom(toobig2d, 32), InternalError); + EXPECT_THROW(hb1d.CopyFrom(toobig2d, 32), tvm::ffi::Error); + EXPECT_THROW(hb2d.CopyFrom(toobig2d, 32), tvm::ffi::Error); } TEST(HexagonBuffer, nd_copy_from_nd_smaller_size) { diff --git a/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc b/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc index 9c74521091aa..edcbb97bbe5a 100644 --- a/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc @@ -53,62 +53,62 @@ class HexagonDeviceAPITest : public ::testing::Test { ffi::Optional global_vtcm_scope = ffi::String("global.vtcm"); }; -TEST_F(HexagonDeviceAPITest, global) { CHECK(hexapi != nullptr); } +TEST_F(HexagonDeviceAPITest, global) { TVM_FFI_ICHECK(hexapi != nullptr); } TEST_F(HexagonDeviceAPITest, alloc_free_cpu) { void* buf = hexapi->AllocDataSpace(cpu_dev, nbytes, alignment, int8); - CHECK(buf != nullptr); + TVM_FFI_ICHECK(buf != nullptr); hexapi->FreeDataSpace(cpu_dev, buf); } TEST_F(HexagonDeviceAPITest, alloc_free_hex) { void* buf = hexapi->AllocDataSpace(hex_dev, nbytes, alignment, int8); - CHECK(buf != nullptr); + TVM_FFI_ICHECK(buf != nullptr); hexapi->FreeDataSpace(hex_dev, buf); } TEST_F(HexagonDeviceAPITest, alloc_errors) { // invalid device - EXPECT_THROW(hexapi->AllocDataSpace(invalid_dev, nbytes, alignment, int8), InternalError); + EXPECT_THROW(hexapi->AllocDataSpace(invalid_dev, nbytes, alignment, int8), tvm::ffi::Error); // 0 size - EXPECT_THROW(hexapi->AllocDataSpace(hex_dev, 0, alignment, int8), InternalError); + EXPECT_THROW(hexapi->AllocDataSpace(hex_dev, 0, alignment, int8), tvm::ffi::Error); // 0 alignment - EXPECT_THROW(hexapi->AllocDataSpace(hex_dev, nbytes, 0, int8), InternalError); + EXPECT_THROW(hexapi->AllocDataSpace(hex_dev, nbytes, 0, int8), tvm::ffi::Error); } TEST_F(HexagonDeviceAPITest, free_errors) { void* buf = hexapi->AllocDataSpace(hex_dev, nbytes, alignment, int8); // invalid device - EXPECT_THROW(hexapi->FreeDataSpace(invalid_dev, buf), InternalError); + EXPECT_THROW(hexapi->FreeDataSpace(invalid_dev, buf), tvm::ffi::Error); // invalid pointer - EXPECT_THROW(hexapi->FreeDataSpace(hex_dev, &buf), InternalError); + EXPECT_THROW(hexapi->FreeDataSpace(hex_dev, &buf), tvm::ffi::Error); // nullptr - EXPECT_THROW(hexapi->FreeDataSpace(hex_dev, nullptr), InternalError); + EXPECT_THROW(hexapi->FreeDataSpace(hex_dev, nullptr), tvm::ffi::Error); // double free hexapi->FreeDataSpace(hex_dev, buf); - EXPECT_THROW(hexapi->FreeDataSpace(hex_dev, buf), InternalError); + EXPECT_THROW(hexapi->FreeDataSpace(hex_dev, buf), tvm::ffi::Error); } TEST_F(HexagonDeviceAPITest, allocnd_free_cpu) { void* buf = hexapi->AllocDataSpace(cpu_dev, 3, shape3d, int8, global_scope); - CHECK(buf != nullptr); + TVM_FFI_ICHECK(buf != nullptr); hexapi->FreeDataSpace(cpu_dev, buf); } TEST_F(HexagonDeviceAPITest, allocnd_free_hex) { void* buf = hexapi->AllocDataSpace(hex_dev, 3, shape3d, int8, global_scope); - CHECK(buf != nullptr); + TVM_FFI_ICHECK(buf != nullptr); hexapi->FreeDataSpace(hex_dev, buf); } TEST_F(HexagonDeviceAPITest, allocnd_free_hex_vtcm) { void* buf1d = hexapi->AllocDataSpace(hex_dev, 1, shape1d, int8, global_vtcm_scope); - CHECK(buf1d != nullptr); + TVM_FFI_ICHECK(buf1d != nullptr); hexapi->FreeDataSpace(hex_dev, buf1d); void* buf2d = hexapi->AllocDataSpace(hex_dev, 2, shape2d, int8, global_vtcm_scope); - CHECK(buf2d != nullptr); + TVM_FFI_ICHECK(buf2d != nullptr); hexapi->FreeDataSpace(hex_dev, buf2d); } @@ -118,27 +118,30 @@ TEST_F(HexagonDeviceAPITest, allocnd_erros) { InternalError); // Hexagon VTCM allocations must have 0 (scalar) 1 or 2 dimensions - EXPECT_THROW(hexapi->AllocDataSpace(hex_dev, 3, shape3d, int8, global_vtcm_scope), InternalError); + EXPECT_THROW(hexapi->AllocDataSpace(hex_dev, 3, shape3d, int8, global_vtcm_scope), + tvm::ffi::Error); // null shape - EXPECT_THROW(hexapi->AllocDataSpace(hex_dev, 2, nullptr, int8, global_vtcm_scope), InternalError); + EXPECT_THROW(hexapi->AllocDataSpace(hex_dev, 2, nullptr, int8, global_vtcm_scope), + tvm::ffi::Error); // null shape - EXPECT_THROW(hexapi->AllocDataSpace(hex_dev, 2, shape2d, int8, invalid_scope), InternalError); + EXPECT_THROW(hexapi->AllocDataSpace(hex_dev, 2, shape2d, int8, invalid_scope), tvm::ffi::Error); // cpu & global.vtcm scope - EXPECT_THROW(hexapi->AllocDataSpace(cpu_dev, 2, shape2d, int8, global_vtcm_scope), InternalError); + EXPECT_THROW(hexapi->AllocDataSpace(cpu_dev, 2, shape2d, int8, global_vtcm_scope), + tvm::ffi::Error); } TEST_F(HexagonDeviceAPITest, alloc_scalar) { void* cpuscalar = hexapi->AllocDataSpace(cpu_dev, 0, new int64_t, int8, global_scope); - CHECK(cpuscalar != nullptr); + TVM_FFI_ICHECK(cpuscalar != nullptr); void* hexscalar = hexapi->AllocDataSpace(hex_dev, 0, new int64_t, int8, global_vtcm_scope); - CHECK(hexscalar != nullptr); + TVM_FFI_ICHECK(hexscalar != nullptr); hexscalar = hexapi->AllocDataSpace(hex_dev, 0, nullptr, int8, global_vtcm_scope); - CHECK(hexscalar != nullptr); + TVM_FFI_ICHECK(hexscalar != nullptr); } // alloc and free of the same buffer on different devices should throw @@ -148,31 +151,31 @@ TEST_F(HexagonDeviceAPITest, alloc_scalar) { // TODO(HWE): Re-enable or delete this test case once we land on device type strategy TEST_F(HexagonDeviceAPITest, DISABLED_alloc_free_diff_dev) { void* buf = hexapi->AllocDataSpace(hex_dev, nbytes, alignment, int8); - CHECK(buf != nullptr); - EXPECT_THROW(hexapi->FreeDataSpace(cpu_dev, buf), InternalError); + TVM_FFI_ICHECK(buf != nullptr); + EXPECT_THROW(hexapi->FreeDataSpace(cpu_dev, buf), tvm::ffi::Error); } // Ensure runtime buffer manager is properly configured and destroyed // in Acquire/Release TEST_F(HexagonDeviceAPITest, runtime_buffer_manager) { hexapi->ReleaseResources(); - EXPECT_THROW(hexapi->AllocDataSpace(hex_dev, nbytes, alignment, int8), InternalError); + EXPECT_THROW(hexapi->AllocDataSpace(hex_dev, nbytes, alignment, int8), tvm::ffi::Error); hexapi->AcquireResources(); void* runtime_buf = hexapi->AllocDataSpace(hex_dev, nbytes, alignment, int8); - CHECK(runtime_buf != nullptr); + TVM_FFI_ICHECK(runtime_buf != nullptr); hexapi->ReleaseResources(); hexapi->FreeDataSpace(hex_dev, runtime_buf); hexapi->AcquireResources(); - EXPECT_THROW(hexapi->FreeDataSpace(hex_dev, runtime_buf), InternalError); + EXPECT_THROW(hexapi->FreeDataSpace(hex_dev, runtime_buf), tvm::ffi::Error); } // Ensure thread manager is properly configured and destroyed // in Acquire/Release TEST_F(HexagonDeviceAPITest, thread_manager) { HexagonThreadManager* threads = hexapi->ThreadManager(); - CHECK(threads != nullptr); + TVM_FFI_ICHECK(threads != nullptr); hexapi->ReleaseResources(); - EXPECT_THROW(hexapi->ThreadManager(), InternalError); + EXPECT_THROW(hexapi->ThreadManager(), tvm::ffi::Error); hexapi->AcquireResources(); } @@ -180,9 +183,9 @@ TEST_F(HexagonDeviceAPITest, thread_manager) { // in Acquire/Release TEST_F(HexagonDeviceAPITest, user_dma) { HexagonUserDMA* user_dma = hexapi->UserDMA(); - CHECK(user_dma != nullptr); + TVM_FFI_ICHECK(user_dma != nullptr); hexapi->ReleaseResources(); - EXPECT_THROW(hexapi->UserDMA(), InternalError); + EXPECT_THROW(hexapi->UserDMA(), tvm::ffi::Error); hexapi->AcquireResources(); } @@ -190,8 +193,8 @@ TEST_F(HexagonDeviceAPITest, user_dma) { // in Acquire/Release TEST_F(HexagonDeviceAPITest, vtcm_pool) { HexagonVtcmPool* vtcm_pool = hexapi->VtcmPool(); - CHECK(vtcm_pool != nullptr); + TVM_FFI_ICHECK(vtcm_pool != nullptr); hexapi->ReleaseResources(); - EXPECT_THROW(hexapi->VtcmPool(), InternalError); + EXPECT_THROW(hexapi->VtcmPool(), tvm::ffi::Error); hexapi->AcquireResources(); } diff --git a/tests/cpp-runtime/hexagon/hexagon_thread_manager_tests.cc b/tests/cpp-runtime/hexagon/hexagon_thread_manager_tests.cc index af29a428bc69..03bf07973118 100644 --- a/tests/cpp-runtime/hexagon/hexagon_thread_manager_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_thread_manager_tests.cc @@ -44,19 +44,19 @@ class HexagonThreadManagerTest : public ::testing::Test { TEST_F(HexagonThreadManagerTest, ctor_edge_cases) { // zero threads - ASSERT_THROW(HexagonThreadManager(0, stack_size, pipe_size), InternalError); + ASSERT_THROW(HexagonThreadManager(0, stack_size, pipe_size), tvm::ffi::Error); // too many threads - ASSERT_THROW(HexagonThreadManager(0x10000000, stack_size, pipe_size), InternalError); + ASSERT_THROW(HexagonThreadManager(0x10000000, stack_size, pipe_size), tvm::ffi::Error); // stack too small - ASSERT_THROW(HexagonThreadManager(6, 0, pipe_size), InternalError); + ASSERT_THROW(HexagonThreadManager(6, 0, pipe_size), tvm::ffi::Error); // stack too big - ASSERT_THROW(HexagonThreadManager(6, 0x10000000, pipe_size), InternalError); + ASSERT_THROW(HexagonThreadManager(6, 0x10000000, pipe_size), tvm::ffi::Error); // pipe too small - ASSERT_THROW(HexagonThreadManager(6, stack_size, 9), InternalError); + ASSERT_THROW(HexagonThreadManager(6, stack_size, 9), tvm::ffi::Error); // pipe too big - ASSERT_THROW(HexagonThreadManager(6, stack_size, 0x10000000), InternalError); + ASSERT_THROW(HexagonThreadManager(6, stack_size, 0x10000000), tvm::ffi::Error); // hw resources count doesn't match thread count - ASSERT_THROW(HexagonThreadManager(6, stack_size, pipe_size, {DMA_0}), InternalError); + ASSERT_THROW(HexagonThreadManager(6, stack_size, pipe_size, {DMA_0}), tvm::ffi::Error); // no more than one of each hw resource may be specified ASSERT_THROW(HexagonThreadManager(4, stack_size, pipe_size, {DMA_0, HTP_0, HVX_0, HVX_0}), InternalError); @@ -70,8 +70,8 @@ TEST_F(HexagonThreadManagerTest, ctor_edge_cases) { } TEST_F(HexagonThreadManagerTest, init) { - CHECK(htm != nullptr); - CHECK_EQ(streams.size(), threads); + TVM_FFI_ICHECK(htm != nullptr); + TVM_FFI_ICHECK_EQ(streams.size(), threads); } void get_the_answer(void* answer) { *reinterpret_cast(answer) = 42; } @@ -80,13 +80,13 @@ TEST_F(HexagonThreadManagerTest, dispatch) { htm->Dispatch(streams[0], get_the_answer, &answer); htm->Start(); htm->WaitOnThreads(); - CHECK_EQ(answer, 42); + TVM_FFI_ICHECK_EQ(answer, 42); } TEST_F(HexagonThreadManagerTest, dispatch_wait) { htm->Dispatch(streams[0], get_the_answer, &answer); htm->WaitOnThreads(); - CHECK_EQ(answer, 42); + TVM_FFI_ICHECK_EQ(answer, 42); } TEST_F(HexagonThreadManagerTest, wait_signal) { @@ -94,7 +94,7 @@ TEST_F(HexagonThreadManagerTest, wait_signal) { htm->Signal(streams[1], 0); htm->Dispatch(streams[0], get_the_answer, &answer); htm->WaitOnThreads(); - CHECK_EQ(answer, 42); + TVM_FFI_ICHECK_EQ(answer, 42); } TEST_F(HexagonThreadManagerTest, re_signal) { @@ -103,7 +103,7 @@ TEST_F(HexagonThreadManagerTest, re_signal) { htm->Signal(streams[1], 0); htm->Dispatch(streams[0], get_the_answer, &answer); htm->WaitOnThreads(); - CHECK_EQ(answer, 42); + TVM_FFI_ICHECK_EQ(answer, 42); } TEST_F(HexagonThreadManagerTest, re_wait) { @@ -112,7 +112,7 @@ TEST_F(HexagonThreadManagerTest, re_wait) { htm->Wait(streams[0], 0); htm->Dispatch(streams[0], get_the_answer, &answer); htm->WaitOnThreads(); - CHECK_EQ(answer, 42); + TVM_FFI_ICHECK_EQ(answer, 42); } TEST_F(HexagonThreadManagerTest, wait_signal_x2) { @@ -122,7 +122,7 @@ TEST_F(HexagonThreadManagerTest, wait_signal_x2) { htm->Signal(streams[1], 1); htm->Dispatch(streams[0], get_the_answer, &answer); htm->WaitOnThreads(); - CHECK_EQ(answer, 42); + TVM_FFI_ICHECK_EQ(answer, 42); } TEST_F(HexagonThreadManagerTest, signal_wait) { @@ -130,21 +130,21 @@ TEST_F(HexagonThreadManagerTest, signal_wait) { htm->Wait(streams[0], 0); htm->Dispatch(streams[0], get_the_answer, &answer); htm->WaitOnThreads(); - CHECK_EQ(answer, 42); + TVM_FFI_ICHECK_EQ(answer, 42); } TEST_F(HexagonThreadManagerTest, sync_from_to) { htm->SyncFromTo(streams[1], streams[0]); htm->Dispatch(streams[0], get_the_answer, &answer); htm->WaitOnThreads(); - CHECK_EQ(answer, 42); + TVM_FFI_ICHECK_EQ(answer, 42); } TEST_F(HexagonThreadManagerTest, sync_from_to_self) { htm->SyncFromTo(streams[0], streams[0]); htm->Dispatch(streams[0], get_the_answer, &answer); htm->WaitOnThreads(); - CHECK_EQ(answer, 42); + TVM_FFI_ICHECK_EQ(answer, 42); } TEST_F(HexagonThreadManagerTest, sync_from_to_x2) { @@ -152,7 +152,7 @@ TEST_F(HexagonThreadManagerTest, sync_from_to_x2) { htm->SyncFromTo(streams[1], streams[0]); htm->Dispatch(streams[0], get_the_answer, &answer); htm->WaitOnThreads(); - CHECK_EQ(answer, 42); + TVM_FFI_ICHECK_EQ(answer, 42); } TEST_F(HexagonThreadManagerTest, sync_from_to_all) { @@ -163,7 +163,7 @@ TEST_F(HexagonThreadManagerTest, sync_from_to_all) { htm->SyncFromTo(streams[1], streams[0]); htm->Dispatch(streams[0], get_the_answer, &answer); htm->WaitOnThreads(); - CHECK_EQ(answer, 42); + TVM_FFI_ICHECK_EQ(answer, 42); } TEST_F(HexagonThreadManagerTest, pipe_fill) { @@ -172,7 +172,7 @@ TEST_F(HexagonThreadManagerTest, pipe_fill) { htm->Dispatch(streams[0], get_the_answer, &answer); } htm->WaitOnThreads(); - CHECK_EQ(answer, 42); + TVM_FFI_ICHECK_EQ(answer, 42); } // TODO(HWE): Create a temporary thread manager with a smaller pipe for this test @@ -183,7 +183,7 @@ TEST_F(HexagonThreadManagerTest, pipe_overflow) { } // overflow the pipe bool space = htm->Dispatch(streams[0], get_the_answer, &answer); - CHECK_EQ(space, false); + TVM_FFI_ICHECK_EQ(space, false); } void increment(void* voidptr) { @@ -204,7 +204,7 @@ TEST_F(HexagonThreadManagerTest, producer_consumer) { htm->SyncFromTo(streams[1], streams[0]); htm->Dispatch(streams[0], increment, &answer); htm->WaitOnThreads(); - CHECK_EQ(answer, 6); + TVM_FFI_ICHECK_EQ(answer, 6); } TEST_F(HexagonThreadManagerTest, producer_consumer_signal_wait) { @@ -226,7 +226,7 @@ TEST_F(HexagonThreadManagerTest, producer_consumer_signal_wait) { htm->Signal(streams[1], 0); htm->Dispatch(streams[0], increment, &answer); htm->WaitOnThreads(); - CHECK_EQ(answer, 6); + TVM_FFI_ICHECK_EQ(answer, 6); } struct ToAppend { @@ -267,7 +267,7 @@ TEST_F(HexagonThreadManagerTest, thread_order) { htm->Dispatch(streams[5], append, &cmd5); htm->WaitOnThreads(); for (int i = 0; i < threads; ++i) { - CHECK_EQ(arr[i], i); + TVM_FFI_ICHECK_EQ(arr[i], i); } } @@ -304,7 +304,7 @@ TEST_F(HexagonThreadManagerTest, thread_order_signal_wait) { htm->Dispatch(streams[5], append, &cmd5); htm->WaitOnThreads(); for (int i = 0; i < threads; ++i) { - CHECK_EQ(arr[i], i); + TVM_FFI_ICHECK_EQ(arr[i], i); } } @@ -334,7 +334,7 @@ TEST_F(HexagonThreadManagerTest, dispatch_writes) { htm->Start(); htm->WaitOnThreads(); for (int i = 0; i < streams.size(); i++) { - CHECK_EQ(array[i], truth[i]); + TVM_FFI_ICHECK_EQ(array[i], truth[i]); } } @@ -344,18 +344,18 @@ TEST_F(HexagonThreadManagerTest, threads_for_resource_types) { TVMStreamHandle thread; thread = thread_manager->GetStreamHandleByResourceType(DMA_0); - CHECK(thread_manager->GetResourceTypeForStreamHandle(thread) == DMA_0); + TVM_FFI_ICHECK(thread_manager->GetResourceTypeForStreamHandle(thread) == DMA_0); thread = thread_manager->GetStreamHandleByResourceType(HTP_0); - CHECK(thread_manager->GetResourceTypeForStreamHandle(thread) == HTP_0); + TVM_FFI_ICHECK(thread_manager->GetResourceTypeForStreamHandle(thread) == HTP_0); thread = thread_manager->GetStreamHandleByResourceType(HVX_0); - CHECK(thread_manager->GetResourceTypeForStreamHandle(thread) == HVX_0); + TVM_FFI_ICHECK(thread_manager->GetResourceTypeForStreamHandle(thread) == HVX_0); thread = thread_manager->GetStreamHandleByResourceType(HVX_1); - CHECK(thread_manager->GetResourceTypeForStreamHandle(thread) == HVX_1); + TVM_FFI_ICHECK(thread_manager->GetResourceTypeForStreamHandle(thread) == HVX_1); thread = thread_manager->GetStreamHandleByResourceType(HVX_2); - CHECK(thread_manager->GetResourceTypeForStreamHandle(thread) == HVX_2); + TVM_FFI_ICHECK(thread_manager->GetResourceTypeForStreamHandle(thread) == HVX_2); thread = thread_manager->GetStreamHandleByResourceType(HVX_3); - CHECK(thread_manager->GetResourceTypeForStreamHandle(thread) == HVX_3); - EXPECT_THROW(thread_manager->GetStreamHandleByResourceType(NONE), InternalError); + TVM_FFI_ICHECK(thread_manager->GetResourceTypeForStreamHandle(thread) == HVX_3); + EXPECT_THROW(thread_manager->GetStreamHandleByResourceType(NONE), tvm::ffi::Error); thread = reinterpret_cast(6); - EXPECT_THROW(thread_manager->GetResourceTypeForStreamHandle(thread), InternalError); + EXPECT_THROW(thread_manager->GetResourceTypeForStreamHandle(thread), tvm::ffi::Error); } diff --git a/tests/cpp-runtime/hexagon/hexagon_vtcm_pool_tests.cc b/tests/cpp-runtime/hexagon/hexagon_vtcm_pool_tests.cc index baa4035e47fb..2743d6e64e95 100644 --- a/tests/cpp-runtime/hexagon/hexagon_vtcm_pool_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_vtcm_pool_tests.cc @@ -48,35 +48,36 @@ TEST_F(HexagonVtcmPoolTest, basic) { void* ptr; void* ptr2; - CHECK(device_bytes >= max_bytes) << "VTCM device size " << device_bytes - << " not greater than or equal to allocated size " << max_bytes; + TVM_FFI_ICHECK(device_bytes >= max_bytes) + << "VTCM device size " << device_bytes << " not greater than or equal to allocated size " + << max_bytes; ptr = vtcm_pool->Allocate(max_bytes); - CHECK((reinterpret_cast(ptr) & 0x7FF) == 0) + TVM_FFI_ICHECK((reinterpret_cast(ptr) & 0x7FF) == 0) << "Must be multiple of 2k " << ptr << " " << max_bytes; vtcm_pool->Free(ptr, max_bytes); ptr = vtcm_pool->Allocate(two_k_block); - CHECK((reinterpret_cast(ptr) & 0x7FF) == 0) + TVM_FFI_ICHECK((reinterpret_cast(ptr) & 0x7FF) == 0) << "Must be multiple of 2k " << ptr << " " << two_k_block; vtcm_pool->Free(ptr, two_k_block); ptr = vtcm_pool->Allocate(one_k_block); - CHECK((reinterpret_cast(ptr) & 0x7F) == 0) + TVM_FFI_ICHECK((reinterpret_cast(ptr) & 0x7F) == 0) << "Must be multiple of 128 " << ptr << " " << one_k_block; vtcm_pool->Free(ptr, one_k_block); ptr = vtcm_pool->Allocate(min_bytes); - CHECK((reinterpret_cast(ptr) & 0x7F) == 0) + TVM_FFI_ICHECK((reinterpret_cast(ptr) & 0x7F) == 0) << "Must be multiple of 128 " << ptr << " " << min_bytes; ptr2 = vtcm_pool->Allocate(one_k_block); - CHECK((reinterpret_cast(ptr) & 0x7F) == 0) + TVM_FFI_ICHECK((reinterpret_cast(ptr) & 0x7F) == 0) << "Must be multiple of 128 " << ptr2 << " " << one_k_block; vtcm_pool->Free(ptr, min_bytes); vtcm_pool->Free(ptr2, one_k_block); - EXPECT_THROW(ptr = vtcm_pool->Allocate(1), InternalError); + EXPECT_THROW(ptr = vtcm_pool->Allocate(1), tvm::ffi::Error); } TEST_F(HexagonVtcmPoolTest, small_allocations) { @@ -95,7 +96,7 @@ TEST_F(HexagonVtcmPoolTest, small_allocations) { ptr3 = vtcm_pool->Allocate(max_bytes - min_bytes - two_k_block); // Should be no more memory left - EXPECT_THROW(ptr4 = vtcm_pool->Allocate(min_bytes), InternalError); + EXPECT_THROW(ptr4 = vtcm_pool->Allocate(min_bytes), tvm::ffi::Error); vtcm_pool->Free(ptr1, min_bytes); vtcm_pool->Free(ptr2, two_k_block); @@ -108,19 +109,19 @@ TEST_F(HexagonVtcmPoolTest, small_allocations) { TEST_F(HexagonVtcmPoolTest, no_free_vtcm) { void* ptr = vtcm_pool->Allocate(max_bytes); - EXPECT_THROW(vtcm_pool->Allocate(min_bytes), InternalError); + EXPECT_THROW(vtcm_pool->Allocate(min_bytes), tvm::ffi::Error); vtcm_pool->Free(ptr, max_bytes); } TEST_F(HexagonVtcmPoolTest, not_enough_free_vtcm) { void* ptr = vtcm_pool->Allocate(max_bytes - two_k_block); - EXPECT_THROW(vtcm_pool->Allocate(two_k_block * 2), InternalError); + EXPECT_THROW(vtcm_pool->Allocate(two_k_block * 2), tvm::ffi::Error); vtcm_pool->Free(ptr, max_bytes - two_k_block); } TEST_F(HexagonVtcmPoolTest, free_with_wrong_size) { void* ptr = vtcm_pool->Allocate(two_k_block * 2); - EXPECT_THROW(vtcm_pool->Free(ptr, two_k_block), InternalError); + EXPECT_THROW(vtcm_pool->Free(ptr, two_k_block), tvm::ffi::Error); vtcm_pool->Free(ptr, two_k_block * 2); } @@ -137,22 +138,22 @@ TEST_F(HexagonVtcmPoolTest, free_alloc_combinations) { ptr4 = vtcm_pool->Allocate(max_less_3_blocks); // Make sure pointers are 2k apart from each other - CHECK(static_cast(ptr1) + two_k_block == static_cast(ptr2)); - CHECK(static_cast(ptr2) + two_k_block == static_cast(ptr3)); - CHECK(static_cast(ptr3) + two_k_block == static_cast(ptr4)); + TVM_FFI_ICHECK(static_cast(ptr1) + two_k_block == static_cast(ptr2)); + TVM_FFI_ICHECK(static_cast(ptr2) + two_k_block == static_cast(ptr3)); + TVM_FFI_ICHECK(static_cast(ptr3) + two_k_block == static_cast(ptr4)); // Free 2, realloc it, make sure it is the same as before vtcm_pool->Free(ptr2, two_k_block); new_ptr = vtcm_pool->Allocate(two_k_block); - CHECK(new_ptr == ptr2); + TVM_FFI_ICHECK(new_ptr == ptr2); // Free 1 and 2, re-alloc and make sure they are the same vtcm_pool->Free(ptr1, two_k_block); vtcm_pool->Free(ptr2, two_k_block); new_ptr = vtcm_pool->Allocate(two_k_block); - CHECK(new_ptr == ptr1); + TVM_FFI_ICHECK(new_ptr == ptr1); new_ptr = vtcm_pool->Allocate(two_k_block); - CHECK(new_ptr == ptr2); + TVM_FFI_ICHECK(new_ptr == ptr2); // Exercise different deletion scenarios vtcm_pool->Free(ptr2, two_k_block); @@ -214,10 +215,10 @@ TEST_F(HexagonVtcmPoolTest, find_smallest_allocation_combinations) { // Reallocate memory allocations and ensure that the smallest free allocations are used. new_ptr = vtcm_pool->Allocate(two_k_block); - CHECK(new_ptr == ptr2); + TVM_FFI_ICHECK(new_ptr == ptr2); new_ptr = vtcm_pool->Allocate(two_k_block); - CHECK(new_ptr == ptr3); + TVM_FFI_ICHECK(new_ptr == ptr3); vtcm_pool->Free(ptr1, two_k_block); vtcm_pool->Free(ptr2, two_k_block); @@ -236,10 +237,10 @@ TEST_F(HexagonVtcmPoolTest, find_smallest_allocation_combinations) { // Reallocate memory allocations and ensure that the smallest free allocations are used. new_ptr = vtcm_pool->Allocate(min_bytes); - CHECK(new_ptr == ptr2); + TVM_FFI_ICHECK(new_ptr == ptr2); new_ptr = vtcm_pool->Allocate(one_k_block); - CHECK(new_ptr == ptr3); + TVM_FFI_ICHECK(new_ptr == ptr3); vtcm_pool->Free(ptr1, min_bytes); vtcm_pool->Free(ptr2, min_bytes); @@ -264,22 +265,28 @@ TEST_F(HexagonVtcmPoolTest, vtcm_alignment) { // Valid alignments, sizes need to be adjusted ptr = test_hexbuffs->AllocateHexagonBuffer(1, 128, ffi::String("global")); - CHECK((reinterpret_cast(ptr) & 0x7F) == 0) << "Must be multiple of 128 " << ptr; + TVM_FFI_ICHECK((reinterpret_cast(ptr) & 0x7F) == 0) + << "Must be multiple of 128 " << ptr; ptr = test_hexbuffs->AllocateHexagonBuffer(127, 128, ffi::String("global")); - CHECK((reinterpret_cast(ptr) & 0x7F) == 0) << "Must be multiple of 128 " << ptr; + TVM_FFI_ICHECK((reinterpret_cast(ptr) & 0x7F) == 0) + << "Must be multiple of 128 " << ptr; ptr = test_hexbuffs->AllocateHexagonBuffer(129, 128, ffi::String("global")); - CHECK((reinterpret_cast(ptr) & 0x7F) == 0) << "Must be multiple of 128 " << ptr; + TVM_FFI_ICHECK((reinterpret_cast(ptr) & 0x7F) == 0) + << "Must be multiple of 128 " << ptr; ptr = test_hexbuffs->AllocateHexagonBuffer(1, 2048, ffi::String("global")); - CHECK((reinterpret_cast(ptr) & 0x7FF) == 0) << "Must be multiple of 2k " << ptr; + TVM_FFI_ICHECK((reinterpret_cast(ptr) & 0x7FF) == 0) + << "Must be multiple of 2k " << ptr; ptr = test_hexbuffs->AllocateHexagonBuffer(2047, 2048, ffi::String("global")); - CHECK((reinterpret_cast(ptr) & 0x7FF) == 0) << "Must be multiple of 2k " << ptr; + TVM_FFI_ICHECK((reinterpret_cast(ptr) & 0x7FF) == 0) + << "Must be multiple of 2k " << ptr; ptr = test_hexbuffs->AllocateHexagonBuffer(2049, 2048, ffi::String("global")); - CHECK((reinterpret_cast(ptr) & 0x7FF) == 0) << "Must be multiple of 2k " << ptr; + TVM_FFI_ICHECK((reinterpret_cast(ptr) & 0x7FF) == 0) + << "Must be multiple of 2k " << ptr; test_hexbuffs.reset(); diff --git a/tests/cpp-runtime/hexagon/ring_buffer_tests.cc b/tests/cpp-runtime/hexagon/ring_buffer_tests.cc index a3abf82b863f..4ee003e373fb 100644 --- a/tests/cpp-runtime/hexagon/ring_buffer_tests.cc +++ b/tests/cpp-runtime/hexagon/ring_buffer_tests.cc @@ -46,7 +46,7 @@ class RingBufferTest : public ::testing::Test { }; TEST_F(RingBufferTest, zero_size_ring_buffer) { - ASSERT_THROW(RingBuffer(0, in_flight), InternalError); + ASSERT_THROW(RingBuffer(0, in_flight), tvm::ffi::Error); } TEST_F(RingBufferTest, in_flight) { ASSERT_EQ(ring_buff->InFlight(), 0); } @@ -201,10 +201,10 @@ class QueuedRingBufferTest : public RingBufferTest { }; TEST_F(QueuedRingBufferTest, invalid_queue) { - ASSERT_THROW(queued_ring_buff->Next(MAX_QUEUES), InternalError); - ASSERT_THROW(queued_ring_buff->InFlight(MAX_QUEUES), InternalError); - ASSERT_THROW(queued_ring_buff->StartGroup(MAX_QUEUES), InternalError); - ASSERT_THROW(queued_ring_buff->EndGroup(MAX_QUEUES), InternalError); + ASSERT_THROW(queued_ring_buff->Next(MAX_QUEUES), tvm::ffi::Error); + ASSERT_THROW(queued_ring_buff->InFlight(MAX_QUEUES), tvm::ffi::Error); + ASSERT_THROW(queued_ring_buff->StartGroup(MAX_QUEUES), tvm::ffi::Error); + ASSERT_THROW(queued_ring_buff->EndGroup(MAX_QUEUES), tvm::ffi::Error); } TEST_F(QueuedRingBufferTest, two_queues) { @@ -228,22 +228,22 @@ TEST_F(QueuedRingBufferTest, two_queues) { } TEST_F(QueuedRingBufferTest, group_end_before_group_start) { - ASSERT_THROW(queued_ring_buff->EndGroup(0), InternalError); + ASSERT_THROW(queued_ring_buff->EndGroup(0), tvm::ffi::Error); } TEST_F(QueuedRingBufferTest, group_restart) { queued_ring_buff->StartGroup(0); - ASSERT_THROW(queued_ring_buff->StartGroup(0), InternalError); + ASSERT_THROW(queued_ring_buff->StartGroup(0), tvm::ffi::Error); } TEST_F(QueuedRingBufferTest, zero_size_group) { queued_ring_buff->StartGroup(0); - ASSERT_THROW(queued_ring_buff->EndGroup(0), InternalError); + ASSERT_THROW(queued_ring_buff->EndGroup(0), tvm::ffi::Error); } TEST_F(QueuedRingBufferTest, in_flight_before_group_end) { queued_ring_buff->StartGroup(0); - ASSERT_THROW(queued_ring_buff->InFlight(0), InternalError); + ASSERT_THROW(queued_ring_buff->InFlight(0), tvm::ffi::Error); } TEST_F(QueuedRingBufferTest, group_of_one) { diff --git a/tests/cpp-runtime/opencl/opencl_nativeptr.cc b/tests/cpp-runtime/opencl/opencl_nativeptr.cc index 1694de418b5c..5b9bb24a111e 100644 --- a/tests/cpp-runtime/opencl/opencl_nativeptr.cc +++ b/tests/cpp-runtime/opencl/opencl_nativeptr.cc @@ -55,8 +55,8 @@ TEST(OpenCLNatvePtr, data_loop) { cpu_arr.CopyTo(cl_arr); void* nptr = workspace->GetNativePtr(cl_arr); for (size_t i = 0; i < 1024; ++i) { - ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - static_cast(nptr)[i]), - 1e-5); + TVM_FFI_ICHECK_LT( + std::fabs(static_cast(cpu_arr->data)[i] - static_cast(nptr)[i]), 1e-5); } // Random initialize cl ndarray @@ -66,8 +66,8 @@ TEST(OpenCLNatvePtr, data_loop) { // Do a roundtrip from native ptr to cl arr to cpu array. cl_arr.CopyTo(cpu_arr); for (size_t i = 0; i < 1024; ++i) { - ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - static_cast(nptr)[i]), - 1e-5); + TVM_FFI_ICHECK_LT( + std::fabs(static_cast(cpu_arr->data)[i] - static_cast(nptr)[i]), 1e-5); } } diff --git a/tests/cpp-runtime/opencl/opencl_timer_test.cc b/tests/cpp-runtime/opencl/opencl_timer_test.cc index ec038be5406c..e29906e91f91 100644 --- a/tests/cpp-runtime/opencl/opencl_timer_test.cc +++ b/tests/cpp-runtime/opencl/opencl_timer_test.cc @@ -59,5 +59,5 @@ TEST(OpenCLTimerNode, nested_timers) { delete[] tmp_buf; int64_t elapsed = init_timer->SyncAndGetElapsedNanos(); - CHECK_EQ(elapsed, nested_time_sum); + TVM_FFI_ICHECK_EQ(elapsed, nested_time_sum); } diff --git a/tests/cpp-runtime/opencl/texture_copy_test.cc b/tests/cpp-runtime/opencl/texture_copy_test.cc index 001e65b90126..3cc4f2e5b8a5 100644 --- a/tests/cpp-runtime/opencl/texture_copy_test.cc +++ b/tests/cpp-runtime/opencl/texture_copy_test.cc @@ -84,7 +84,7 @@ TEST(TextureCopy, HostDeviceRT) { cpu_arr0.CopyTo(opencl_txarr0); opencl_txarr0.CopyTo(cpu_arr1); for (size_t i = 0; i < size; ++i) { - ICHECK_LT( + TVM_FFI_ICHECK_LT( std::fabs(static_cast(cpu_arr1->data)[i] - static_cast(cpu_arr0->data)[i]), 1e-5); } @@ -127,9 +127,9 @@ TEST_F(TextureCopyTest, ViewBufferAsBuffer) { // Copy from OpenCLBuffer opencl_memobj.CopyTo(cpu_arr_ret); for (size_t i = 0; i < size; i++) { - ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - - static_cast(cpu_arr_ret->data)[i]), - 1e-5); + TVM_FFI_ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - + static_cast(cpu_arr_ret->data)[i]), + 1e-5); } /* Check view object round trip */ @@ -142,9 +142,9 @@ TEST_F(TextureCopyTest, ViewBufferAsBuffer) { // Copy from OpenCLBuffer opencl_memview.CopyTo(cpu_arr_ret); for (size_t i = 0; i < size; i++) { - ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - - static_cast(cpu_arr_ret->data)[i]), - 1e-5); + TVM_FFI_ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - + static_cast(cpu_arr_ret->data)[i]), + 1e-5); } } @@ -184,9 +184,9 @@ TEST_F(TextureCopyTest, ViewBufferAsImage) { // Copy from OpenCLBuffer opencl_buf_obj.CopyTo(cpu_arr_ret); for (size_t i = 0; i < size; i++) { - ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - - static_cast(cpu_arr_ret->data)[i]), - 1e-5); + TVM_FFI_ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - + static_cast(cpu_arr_ret->data)[i]), + 1e-5); } /* Check view object round trip */ @@ -199,9 +199,9 @@ TEST_F(TextureCopyTest, ViewBufferAsImage) { // Copy from OpenCLBuffer opencl_img_obj.CopyTo(cpu_arr_ret); for (size_t i = 0; i < size; i++) { - ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - - static_cast(cpu_arr_ret->data)[i]), - 1e-5); + TVM_FFI_ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - + static_cast(cpu_arr_ret->data)[i]), + 1e-5); } } @@ -242,9 +242,9 @@ TEST_F(TextureCopyTest, ViewImageAsBuffer) { // Copy from OpenCLBuffer opencl_buf_obj.CopyTo(cpu_arr_ret); for (size_t i = 0; i < size; i++) { - ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - - static_cast(cpu_arr_ret->data)[i]), - 1e-5); + TVM_FFI_ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - + static_cast(cpu_arr_ret->data)[i]), + 1e-5); } /* Check view object round trip */ @@ -257,9 +257,9 @@ TEST_F(TextureCopyTest, ViewImageAsBuffer) { // Copy from OpenCLBuffer opencl_img_obj.CopyTo(cpu_arr_ret); for (size_t i = 0; i < size; i++) { - ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - - static_cast(cpu_arr_ret->data)[i]), - 1e-5); + TVM_FFI_ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - + static_cast(cpu_arr_ret->data)[i]), + 1e-5); } } @@ -300,9 +300,9 @@ TEST_F(TextureCopyTest, ViewImageAsImage) { // Copy from OpenCLBuffer opencl_img_obj_1.CopyTo(cpu_arr_ret); for (size_t i = 0; i < size; i++) { - ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - - static_cast(cpu_arr_ret->data)[i]), - 1e-5); + TVM_FFI_ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - + static_cast(cpu_arr_ret->data)[i]), + 1e-5); } /* Check view object round trip */ @@ -315,8 +315,8 @@ TEST_F(TextureCopyTest, ViewImageAsImage) { // Copy from OpenCLBuffer opencl_img_obj_2.CopyTo(cpu_arr_ret); for (size_t i = 0; i < size; i++) { - ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - - static_cast(cpu_arr_ret->data)[i]), - 1e-5); + TVM_FFI_ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - + static_cast(cpu_arr_ret->data)[i]), + 1e-5); } } diff --git a/tests/cpp/arith_simplify_test.cc b/tests/cpp/arith_simplify_test.cc index 23bcd8a7a7e5..495739d76608 100644 --- a/tests/cpp/arith_simplify_test.cc +++ b/tests/cpp/arith_simplify_test.cc @@ -27,11 +27,11 @@ TEST(Simplify, MinMax) { auto x = tvm::te::var("x"); auto e1 = (tvm::max(x, 1) - tvm::max(x, 1)); auto e1s = ana.canonical_simplify(e1); - ICHECK(tvm::tir::is_zero(e1s)); + TVM_FFI_ICHECK(tvm::tir::is_zero(e1s)); auto e2 = (x * tvm::min(x, 1)) - (x * tvm::min(x, 1)); auto e2s = ana.canonical_simplify(e2); - ICHECK(tvm::tir::is_zero(e2s)); + TVM_FFI_ICHECK(tvm::tir::is_zero(e2s)); } TEST(Simplify, Mul) { @@ -39,7 +39,7 @@ TEST(Simplify, Mul) { auto x = tvm::te::var("x"); auto e = (x * x) - (x * x); auto es = ana.canonical_simplify(e); - ICHECK(tvm::tir::is_zero(es)); + TVM_FFI_ICHECK(tvm::tir::is_zero(es)); } TEST(Simplify, Mod) { @@ -51,7 +51,7 @@ TEST(Simplify, Mod) { // and therefore, the constant folding will be attempted in CanonicalSimplify auto mod = ana.canonical_simplify(tvm::tir::Mod(x, y)); auto es = ana.canonical_simplify(mod - x); - ICHECK(tvm::tir::is_zero(es)); + TVM_FFI_ICHECK(tvm::tir::is_zero(es)); } TEST(ConstantFold, Broadcast) { diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc index 05fbd5ce548c..67c9fe99cf30 100644 --- a/tests/cpp/expr_test.cc +++ b/tests/cpp/expr_test.cc @@ -31,8 +31,8 @@ TEST(Expr, Basic) { PrimExpr zz = Downcast(tmp); std::ostringstream os; os << z; - ICHECK(zz.same_as(z)); - ICHECK(os.str() == "T.max(x + 1 + 2, 100)"); + TVM_FFI_ICHECK(zz.same_as(z)); + TVM_FFI_ICHECK(os.str() == "T.max(x + 1 + 2, 100)"); } TEST(Expr, VarTypeAnnotation) { @@ -41,8 +41,8 @@ TEST(Expr, VarTypeAnnotation) { Var x("x", DataType::Float(32)); Var y("y", PrimType(DataType::Float(32))); StructuralEqual checker; - ICHECK(checker(x->dtype, y->dtype)); - ICHECK(checker(x->type_annotation, y->type_annotation)); + TVM_FFI_ICHECK(checker(x->dtype, y->dtype)); + TVM_FFI_ICHECK(checker(x->type_annotation, y->type_annotation)); } TEST(ExprNodeRef, Basic) { @@ -51,5 +51,5 @@ TEST(ExprNodeRef, Basic) { Var x("x"); PrimExpr z = max(x + 1 + 2, 100); const tir::MaxNode* op = z.as(); - ICHECK(ffi::GetRef(op).same_as(z)); + TVM_FFI_ICHECK(ffi::GetRef(op).same_as(z)); } diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 83418d352b5b..9dafe025f74b 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -38,8 +38,8 @@ TEST(IRF, Basic) { NodeFunctor f; f.set_dispatch([](const ObjectRef& n, int b) { return b; }); f.set_dispatch([](const ObjectRef& n, int b) { return b + 2; }); - ICHECK_EQ(f(x, 2), 2); - ICHECK_EQ(f(z, 2), 4); + TVM_FFI_ICHECK_EQ(f(x, 2), 2); + TVM_FFI_ICHECK_EQ(f(z, 2), 4); } TEST(IRF, CountVar) { @@ -52,7 +52,7 @@ TEST(IRF, CountVar) { tir::PostOrderVisit(z, [&n_var](const ObjectRef& n) { if (n.as()) ++n_var; }); - ICHECK_EQ(n_var, 2); + TVM_FFI_ICHECK_EQ(n_var, 2); } TEST(IRF, PreOrderVisit) { @@ -78,7 +78,7 @@ TEST(IRF, PreOrderVisit) { } else if (int_imm->value == 1) { body_visited = true; } else { - LOG(FATAL) << "Unreachable"; + TVM_FFI_THROW(InternalError) << "Unreachable"; } } } @@ -104,11 +104,11 @@ TEST(IRF, ExprTransform) { } }; MyExprFunctor f; - ICHECK_EQ(f(x, 2), 2); - ICHECK_EQ(f(z, 2), 3); + TVM_FFI_ICHECK_EQ(f(x, 2), 2); + TVM_FFI_ICHECK_EQ(f(z, 2), 3); try { f(z - 1, 2); - LOG(FATAL) << "should fail"; + TVM_FFI_THROW(InternalError) << "should fail"; } catch (Error&) { } } @@ -134,7 +134,7 @@ TEST(IRF, ExprVisit) { }; MyVisitor v; v.VisitStmt(Evaluate(z)); - ICHECK_EQ(v.count, 1); + TVM_FFI_ICHECK_EQ(v.count, 1); } TEST(IRF, StmtVisitor) { @@ -156,7 +156,7 @@ TEST(IRF, StmtVisitor) { return Allocate(buffer, dtype, {z, z}, const_true(), body); }; v(fmaketest()); - ICHECK_EQ(v.count, 3); + TVM_FFI_ICHECK_EQ(v.count, 3); { // tests for block and block_realize @@ -175,7 +175,7 @@ TEST(IRF, StmtVisitor) { v.count = 0; v(block_realize); - ICHECK_EQ(v.count, 9); + TVM_FFI_ICHECK_EQ(v.count, 9); } } @@ -218,14 +218,14 @@ TEST(IRF, StmtMutator) { ffi::Array arr{std::move(body), body2, body2}; auto* arrptr = arr.get(); arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); - ICHECK(arr.get() == arrptr); + TVM_FFI_ICHECK(arr.get() == arrptr); // inplace update body - ICHECK(arr[0].as()->extents[1].same_as(x)); - ICHECK(arr[0].as()->extents.get() == extentptr); + TVM_FFI_ICHECK(arr[0].as()->extents[1].same_as(x)); + TVM_FFI_ICHECK(arr[0].as()->extents.get() == extentptr); // copy because there is additional refs - ICHECK(!arr[0].as()->body.same_as(bref)); - ICHECK(arr[0].as()->body.as()->value.same_as(x)); - ICHECK(bref.as()->value.as()); + TVM_FFI_ICHECK(!arr[0].as()->body.same_as(bref)); + TVM_FFI_ICHECK(arr[0].as()->body.as()->value.same_as(x)); + TVM_FFI_ICHECK(bref.as()->value.as()); } { ffi::Array arr{fmakealloc()}; @@ -233,29 +233,29 @@ TEST(IRF, StmtMutator) { ffi::Array arr2 = arr; auto* arrptr = arr.get(); arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); - ICHECK(arr.get() != arrptr); - ICHECK(arr[0].as()->extents[1].same_as(x)); - ICHECK(!arr2[0].as()->extents[1].same_as(x)); + TVM_FFI_ICHECK(arr.get() != arrptr); + TVM_FFI_ICHECK(arr[0].as()->extents[1].same_as(x)); + TVM_FFI_ICHECK(!arr2[0].as()->extents[1].same_as(x)); // mutate but no content change. arr2 = arr; arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); - ICHECK(arr2.get() == arr.get()); + TVM_FFI_ICHECK(arr2.get() == arr.get()); } { ffi::Array arr{fmakeif()}; arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); - ICHECK(arr[0].as()->else_case.as()->value.same_as(x)); + TVM_FFI_ICHECK(arr[0].as()->else_case.as()->value.same_as(x)); // mutate but no content change. auto arr2 = arr; arr.MutateByApply([&](Stmt s) { return v(std::move(s)); }); - ICHECK(arr2.get() == arr.get()); + TVM_FFI_ICHECK(arr2.get() == arr.get()); } { auto body = Evaluate(Call(DataType::Int(32), builtin::call_extern(), {StringImm("xyz"), x + 1})); auto res = v(std::move(body)); - ICHECK(res.as()->value.as()->args[1].same_as(x)); + TVM_FFI_ICHECK(res.as()->value.as()->args[1].same_as(x)); } { Stmt body = fmakealloc(); @@ -267,9 +267,9 @@ TEST(IRF, StmtMutator) { body = SeqStmt({body, body2}); body = v(std::move(body)); // the seq get flattened - ICHECK(body.as()->size() == 3); - ICHECK(body.as()->seq[0].as()->extents.get() == extentptr); - ICHECK(body.as()->seq[1].get() == ref2); + TVM_FFI_ICHECK(body.as()->size() == 3); + TVM_FFI_ICHECK(body.as()->seq[0].as()->extents.get() == extentptr); + TVM_FFI_ICHECK(body.as()->seq[1].get() == ref2); } { @@ -283,7 +283,7 @@ TEST(IRF, StmtMutator) { body = SeqStmt({body, body2}); body = v(std::move(body)); // the seq get flattened - ICHECK(body.as()->seq[0].as()->extents.get() != extentptr); + TVM_FFI_ICHECK(body.as()->seq[0].as()->extents.get() != extentptr); } { @@ -302,11 +302,13 @@ TEST(IRF, StmtMutator) { body = v(std::move(block_realize)); // the body should be changed SBlock new_block = body.as()->block; - ICHECK(new_block->body.as()->body.as()->extents[1].same_as(x)); - ICHECK(new_block->init.as()->body.as()->extents[1].same_as(x)); - ICHECK(new_block->reads[0]->region[0]->min.same_as(x)); - ICHECK(new_block->writes[0]->region[0]->min.same_as(x)); - ICHECK(new_block->match_buffers[0]->source->region[0]->min.same_as(x)); + TVM_FFI_ICHECK( + new_block->body.as()->body.as()->extents[1].same_as(x)); + TVM_FFI_ICHECK( + new_block->init.as()->body.as()->extents[1].same_as(x)); + TVM_FFI_ICHECK(new_block->reads[0]->region[0]->min.same_as(x)); + TVM_FFI_ICHECK(new_block->writes[0]->region[0]->min.same_as(x)); + TVM_FFI_ICHECK(new_block->match_buffers[0]->source->region[0]->min.same_as(x)); } } @@ -339,7 +341,7 @@ TEST(IRF, Substitute) { return std::nullopt; }; BufferLoad new_buffer_load = Downcast(Substitute(buffer_load, f_subst)); - ICHECK(new_buffer_load->buffer->data.same_as(y)); + TVM_FFI_ICHECK(new_buffer_load->buffer->data.same_as(y)); } { @@ -348,6 +350,6 @@ TEST(IRF, Substitute) { auto f_subst = [&](const Var& var) -> ffi::Optional { return var; }; PrimExpr new_expr = Substitute(expr, f_subst); // the expression is not changed - ICHECK(new_expr.same_as(expr)); + TVM_FFI_ICHECK(new_expr.same_as(expr)); } } diff --git a/tests/cpp/ndarray_test.cc b/tests/cpp/ndarray_test.cc index c2452f9146b1..fdb064b4a46b 100644 --- a/tests/cpp/ndarray_test.cc +++ b/tests/cpp/ndarray_test.cc @@ -30,7 +30,7 @@ TEST(TensorTest, IsContiguous_ContiguousStride) { int64_t strides[] = {10, 1}; managed_tensor->dl_tensor.strides = strides; - ICHECK(runtime::IsContiguous(managed_tensor->dl_tensor)); + TVM_FFI_ICHECK(runtime::IsContiguous(managed_tensor->dl_tensor)); managed_tensor->deleter(managed_tensor); } @@ -41,7 +41,7 @@ TEST(TensorTest, IsContiguous_NullStride) { managed_tensor->dl_tensor.strides = nullptr; - ICHECK(runtime::IsContiguous(managed_tensor->dl_tensor)); + TVM_FFI_ICHECK(runtime::IsContiguous(managed_tensor->dl_tensor)); managed_tensor->deleter(managed_tensor); } @@ -53,7 +53,7 @@ TEST(TensorTest, IsContiguous_AnyStrideForSingular) { int64_t strides[] = {10, 1, 1}; // strides[1] is normalized to 1 because shape[1] == 1. managed_tensor->dl_tensor.strides = strides; - ICHECK(runtime::IsContiguous(managed_tensor->dl_tensor)); + TVM_FFI_ICHECK(runtime::IsContiguous(managed_tensor->dl_tensor)); managed_tensor->dl_tensor.strides = nullptr; managed_tensor->deleter(managed_tensor); @@ -66,7 +66,7 @@ TEST(TensorTest, IsContiguous_UncontiguousStride) { int64_t strides[] = {1, 1, 1}; managed_tensor->dl_tensor.strides = strides; - ICHECK(!runtime::IsContiguous(managed_tensor->dl_tensor)); + TVM_FFI_ICHECK(!runtime::IsContiguous(managed_tensor->dl_tensor)); managed_tensor->dl_tensor.strides = nullptr; managed_tensor->deleter(managed_tensor); diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc index c9628daf0d80..b1f7b80c996a 100644 --- a/tests/cpp/nested_msg_test.cc +++ b/tests/cpp/nested_msg_test.cc @@ -233,7 +233,7 @@ TEST(NestedMsg, NestedMsgToExpr) { NestedMsg msg = {c0, {c0, c1}, {c0, {c1, c2}}}; auto expr = NestedMsgToExpr(msg, [&](ffi::Optional leaf) { - ICHECK(leaf.defined()); + TVM_FFI_ICHECK(leaf.defined()); int value = leaf.value().IntValue(); switch (value) { case 0: diff --git a/tests/cpp/object_protocol_test.cc b/tests/cpp/object_protocol_test.cc index fc02fb036bcf..1cf8cfb0e9be 100644 --- a/tests/cpp/object_protocol_test.cc +++ b/tests/cpp/object_protocol_test.cc @@ -61,26 +61,26 @@ TEST(ObjectHierachy, Basic) { using namespace tvm::ffi; ObjectRef refA(ffi::make_object()); - ICHECK_EQ(refA->type_index(), ObjA::RuntimeTypeIndex()); - ICHECK(refA.as() != nullptr); - ICHECK(refA.as() != nullptr); - ICHECK(refA.as() != nullptr); - ICHECK(refA.as() == nullptr); - ICHECK(refA.as() == nullptr); + TVM_FFI_ICHECK_EQ(refA->type_index(), ObjA::RuntimeTypeIndex()); + TVM_FFI_ICHECK(refA.as() != nullptr); + TVM_FFI_ICHECK(refA.as() != nullptr); + TVM_FFI_ICHECK(refA.as() != nullptr); + TVM_FFI_ICHECK(refA.as() == nullptr); + TVM_FFI_ICHECK(refA.as() == nullptr); ObjectRef refAA(ffi::make_object()); - ICHECK_EQ(refAA->type_index(), ObjAA::RuntimeTypeIndex()); - ICHECK(refAA.as() != nullptr); - ICHECK(refAA.as() != nullptr); - ICHECK(refAA.as() != nullptr); - ICHECK(refAA.as() != nullptr); - ICHECK(refAA.as() == nullptr); + TVM_FFI_ICHECK_EQ(refAA->type_index(), ObjAA::RuntimeTypeIndex()); + TVM_FFI_ICHECK(refAA.as() != nullptr); + TVM_FFI_ICHECK(refAA.as() != nullptr); + TVM_FFI_ICHECK(refAA.as() != nullptr); + TVM_FFI_ICHECK(refAA.as() != nullptr); + TVM_FFI_ICHECK(refAA.as() == nullptr); ObjectRef refB(ffi::make_object()); - ICHECK_EQ(refB->type_index(), ObjB::RuntimeTypeIndex()); - ICHECK(refB.as() != nullptr); - ICHECK(refB.as() != nullptr); - ICHECK(refB.as() == nullptr); - ICHECK(refB.as() == nullptr); - ICHECK(refB.as() != nullptr); + TVM_FFI_ICHECK_EQ(refB->type_index(), ObjB::RuntimeTypeIndex()); + TVM_FFI_ICHECK(refB.as() != nullptr); + TVM_FFI_ICHECK(refB.as() != nullptr); + TVM_FFI_ICHECK(refB.as() == nullptr); + TVM_FFI_ICHECK(refB.as() == nullptr); + TVM_FFI_ICHECK(refB.as() != nullptr); } diff --git a/tests/cpp/parallel_for_test.cc b/tests/cpp/parallel_for_test.cc index 2057044cc13f..0ca486e81598 100644 --- a/tests/cpp/parallel_for_test.cc +++ b/tests/cpp/parallel_for_test.cc @@ -35,7 +35,7 @@ TEST(ParallelFor, Basic) { } parallel_for(0, 10, [&b](int i) { b[i] = i; }); for (int i = 0; i < 10; i++) { - ICHECK_EQ(a[i], b[i]); + TVM_FFI_ICHECK_EQ(a[i], b[i]); } // Check for a large size of parallel @@ -44,7 +44,7 @@ TEST(ParallelFor, Basic) { } parallel_for(0, 1000, [&b](int i) { b[i] = i; }); for (int i = 0; i < 1000; i++) { - ICHECK_EQ(a[i], b[i]); + TVM_FFI_ICHECK_EQ(a[i], b[i]); } // Check for step != 1 @@ -54,7 +54,7 @@ TEST(ParallelFor, Basic) { parallel_for( 0, 1000, [&b](int i) { b[i] *= 2; }, 2); for (int i = 0; i < 1000; i++) { - ICHECK_EQ(a[i], b[i]); + TVM_FFI_ICHECK_EQ(a[i], b[i]); } } @@ -76,7 +76,7 @@ TEST(ParallelFor, NestedWithNormalForLoop) { }); for (int i = 0; i < 500; i++) { for (int j = 0; j < 500; j++) { - ICHECK_EQ(a[i][j], b[i][j]); + TVM_FFI_ICHECK_EQ(a[i][j], b[i][j]); } } @@ -85,7 +85,7 @@ TEST(ParallelFor, NestedWithNormalForLoop) { } for (int i = 0; i < 500; i++) { for (int j = 0; j < 500; j++) { - ICHECK_EQ(a[i][j], c[i][j]); + TVM_FFI_ICHECK_EQ(a[i][j], c[i][j]); } } } @@ -104,7 +104,7 @@ TEST(ParallelFor, NestedWithParallelFor) { } catch (const std::exception& e) { exception = true; } - ICHECK(exception); + TVM_FFI_ICHECK(exception); } TEST(ParallelFor, Exception) { @@ -112,11 +112,11 @@ TEST(ParallelFor, Exception) { bool exception = false; try { - parallel_for(0, 100, [](int i) { LOG(FATAL) << "error"; }); + parallel_for(0, 100, [](int i) { TVM_FFI_THROW(InternalError) << "error"; }); } catch (const std::exception& e) { exception = true; } - ICHECK(exception); + TVM_FFI_ICHECK(exception); } TEST(ParallelForDynamic, Basic) { @@ -125,7 +125,7 @@ TEST(ParallelForDynamic, Basic) { int num_threads = std::thread::hardware_concurrency(); parallel_for_dynamic(0, 1000, num_threads, [&a](int thread_id, int i) { a[i] = i; }); for (int i = 0; i < 1000; i++) { - ICHECK_EQ(a[i], i); + TVM_FFI_ICHECK_EQ(a[i], i); } } @@ -136,13 +136,13 @@ TEST(ParallelForDynamic, ExceptionOnMain) { try { parallel_for_dynamic(0, 10, num_threads, [](int thread_id, int task_id) { if (thread_id == 0) { - LOG(FATAL) << "Error"; + TVM_FFI_THROW(InternalError) << "Error"; } }); } catch (const std::exception& e) { exception = true; } - ICHECK(exception); + TVM_FFI_ICHECK(exception); } TEST(ParallelForDynamic, ExceptionOnArbitrary) { @@ -150,10 +150,11 @@ TEST(ParallelForDynamic, ExceptionOnArbitrary) { int num_threads = 3; bool exception = false; try { - parallel_for_dynamic(0, 100, num_threads, - [](int thread_id, int task_id) { LOG(FATAL) << "Error"; }); + parallel_for_dynamic(0, 100, num_threads, [](int thread_id, int task_id) { + TVM_FFI_THROW(InternalError) << "Error"; + }); } catch (const std::exception& e) { exception = true; } - ICHECK(exception); + TVM_FFI_ICHECK(exception); } diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 70f806e241b5..a763b0b7002b 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -35,95 +35,96 @@ TEST(Pattern, Basic) { // arithmetics auto r = 1 + (y + 1); - ICHECK(!(px + (px + px)).Match(r)); - ICHECK(!(px + (py + py)).Match(r)); - ICHECK((px + (py + pz)).Match(r)); + TVM_FFI_ICHECK(!(px + (px + px)).Match(r)); + TVM_FFI_ICHECK(!(px + (py + py)).Match(r)); + TVM_FFI_ICHECK((px + (py + pz)).Match(r)); auto pattern = px + (py + pz); - ICHECK(pattern.Match(r)); + TVM_FFI_ICHECK(pattern.Match(r)); { - ICHECK((px + (py + px)).Match(r)); + TVM_FFI_ICHECK((px + (py + px)).Match(r)); auto rr = (px + py).Eval(); - ICHECK(tir::ExprDeepEqual()(rr, 1 + y)); - ICHECK(tir::ExprDeepEqual()(px.Eval() + py.Eval(), 1 + y)); + TVM_FFI_ICHECK(tir::ExprDeepEqual()(rr, 1 + y)); + TVM_FFI_ICHECK(tir::ExprDeepEqual()(px.Eval() + py.Eval(), 1 + y)); } { - ICHECK((px + max(py, px)).Match((x + 1) + max(y, (x + 1)))); - ICHECK(tir::ExprDeepEqual()(px.Eval(), x + 1)); + TVM_FFI_ICHECK((px + max(py, px)).Match((x + 1) + max(y, (x + 1)))); + TVM_FFI_ICHECK(tir::ExprDeepEqual()(px.Eval(), x + 1)); } - ICHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1)))); + TVM_FFI_ICHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1)))); - ICHECK((px + min(py, px)).Match(z + min(y, z))); - ICHECK((px + truncdiv(py, px * py)).Match(x + truncdiv(2, x * 2))); - ICHECK((px - truncmod(py, px * pz)).Match(x - truncmod(2, x * 2))); - ICHECK((px - floormod(py, px * PConst(2))).Match(x - floormod(2, x * 2))); + TVM_FFI_ICHECK((px + min(py, px)).Match(z + min(y, z))); + TVM_FFI_ICHECK((px + truncdiv(py, px * py)).Match(x + truncdiv(2, x * 2))); + TVM_FFI_ICHECK((px - truncmod(py, px * pz)).Match(x - truncmod(2, x * 2))); + TVM_FFI_ICHECK((px - floormod(py, px * PConst(2))).Match(x - floormod(2, x * 2))); // logicals - ICHECK((px == pz).Match(x == 1)); - ICHECK((px != pz).Match(x != 1)); - ICHECK((px > py).Match(x > y)); - ICHECK((px < py).Match(x < y)); - ICHECK((px <= py).Match(x <= y)); - ICHECK((px >= py).Match(x >= y)); - ICHECK((px >= py && px < pz).Match(x >= y && x < z)); - ICHECK((!(px > py || px != py)).Match(!(x > y || x != y))); + TVM_FFI_ICHECK((px == pz).Match(x == 1)); + TVM_FFI_ICHECK((px != pz).Match(x != 1)); + TVM_FFI_ICHECK((px > py).Match(x > y)); + TVM_FFI_ICHECK((px < py).Match(x < y)); + TVM_FFI_ICHECK((px <= py).Match(x <= y)); + TVM_FFI_ICHECK((px >= py).Match(x >= y)); + TVM_FFI_ICHECK((px >= py && px < pz).Match(x >= y && x < z)); + TVM_FFI_ICHECK((!(px > py || px != py)).Match(!(x > y || x != y))); { - ICHECK(select(px >= pz, py, py + pz).Match(tir::Select((x + 1) >= 1, y, y + 1))); - ICHECK(tir::ExprDeepEqual()(px.Eval(), x + 1)); + TVM_FFI_ICHECK(select(px >= pz, py, py + pz).Match(tir::Select((x + 1) >= 1, y, y + 1))); + TVM_FFI_ICHECK(tir::ExprDeepEqual()(px.Eval(), x + 1)); } // bit intrinsics { - ICHECK((px >> pz).Match(x >> 1)); - ICHECK(is_const_int(pz.Eval(), 1)); + TVM_FFI_ICHECK((px >> pz).Match(x >> 1)); + TVM_FFI_ICHECK(is_const_int(pz.Eval(), 1)); } - ICHECK(!(px >> pz).Match(x << 1)); - ICHECK((px << pz).Match(x << 1)); - ICHECK((px & pz).Match(x & 1)); - ICHECK((px | pz).Match(x | 1)); - ICHECK((px ^ pz).Match(x ^ 1)); - ICHECK((px - (~(py | (px * pz)))).Match(x - (~(2 | (x * 2))))); + TVM_FFI_ICHECK(!(px >> pz).Match(x << 1)); + TVM_FFI_ICHECK((px << pz).Match(x << 1)); + TVM_FFI_ICHECK((px & pz).Match(x & 1)); + TVM_FFI_ICHECK((px | pz).Match(x | 1)); + TVM_FFI_ICHECK((px ^ pz).Match(x ^ 1)); + TVM_FFI_ICHECK((px - (~(py | (px * pz)))).Match(x - (~(2 | (x * 2))))); // select { - ICHECK(select(px > pz, py, py + pz).Match(tir::Select(x > 1, y, y + 1))); - ICHECK(is_const_int(pz.Eval(), 1)); + TVM_FFI_ICHECK(select(px > pz, py, py + pz).Match(tir::Select(x > 1, y, y + 1))); + TVM_FFI_ICHECK(is_const_int(pz.Eval(), 1)); } - ICHECK(!select(px > pz, py, py + pz).Match(tir::Select(x > 2, y, y + 1))); - ICHECK(!select(px > pz, py, py).Match(tir::Select(x > 2, y, y + 1))); + TVM_FFI_ICHECK(!select(px > pz, py, py + pz).Match(tir::Select(x > 2, y, y + 1))); + TVM_FFI_ICHECK(!select(px > pz, py, py).Match(tir::Select(x > 2, y, y + 1))); { - ICHECK(select(px, py, pz).Match(tir::Select(x > 2, y, y + 1))); - ICHECK(tir::ExprDeepEqual()(pz.Eval(), y + 1)); + TVM_FFI_ICHECK(select(px, py, pz).Match(tir::Select(x > 2, y, y + 1))); + TVM_FFI_ICHECK(tir::ExprDeepEqual()(pz.Eval(), y + 1)); } // if_then_else { - ICHECK(if_then_else(px > pz, py, py + pz).Match(if_then_else(x > 1, y, y + 1))); - ICHECK(is_const_int(pz.Eval(), 1)); + TVM_FFI_ICHECK(if_then_else(px > pz, py, py + pz).Match(if_then_else(x > 1, y, y + 1))); + TVM_FFI_ICHECK(is_const_int(pz.Eval(), 1)); } // cast pattern { - ICHECK(!cast(PConst(DataType::Int(32)), px).Match(tir::Cast(DataType::Float(64), x))); - ICHECK(cast(pt, px).Match(tir::Cast(DataType::Float(64), x))); - ICHECK(pt.Eval() == DataType::Float(64)); + TVM_FFI_ICHECK( + !cast(PConst(DataType::Int(32)), px).Match(tir::Cast(DataType::Float(64), x))); + TVM_FFI_ICHECK(cast(pt, px).Match(tir::Cast(DataType::Float(64), x))); + TVM_FFI_ICHECK(pt.Eval() == DataType::Float(64)); auto zz = cast(pt, px).Eval(); - ICHECK((cast(pt, px) - cast(pt, py)) - .Match(tir::Cast(DataType::Float(64), x) - tir::Cast(DataType::Int(64), x))); + TVM_FFI_ICHECK((cast(pt, px) - cast(pt, py)) + .Match(tir::Cast(DataType::Float(64), x) - tir::Cast(DataType::Int(64), x))); auto expr = tir::Cast(DataType::Int(32), tir::Cast(DataType::Float(64), x)); - ICHECK(!(cast(pt, cast(pt, px))).Match(expr)); + TVM_FFI_ICHECK(!(cast(pt, cast(pt, px))).Match(expr)); } // ramp pattern { - ICHECK(ramp(px, PConst(1), planes).Match(tir::Ramp(x, 1, 10))); - ICHECK(planes.Eval().as()->value == 10); - ICHECK(ramp(px, PConst(1), planes).Match(tir::Ramp(x, 1, scalable_lanes))); - ICHECK((vscale * PConst(4)).Match(planes.Eval())); - ICHECK(!ramp(px, PConst(1), planes).Match(tir::Ramp(x, 2, 10))); + TVM_FFI_ICHECK(ramp(px, PConst(1), planes).Match(tir::Ramp(x, 1, 10))); + TVM_FFI_ICHECK(planes.Eval().as()->value == 10); + TVM_FFI_ICHECK(ramp(px, PConst(1), planes).Match(tir::Ramp(x, 1, scalable_lanes))); + TVM_FFI_ICHECK((vscale * PConst(4)).Match(planes.Eval())); + TVM_FFI_ICHECK(!ramp(px, PConst(1), planes).Match(tir::Ramp(x, 2, 10))); } // broadcast pattern { - ICHECK(broadcast(px, planes).Match(tir::Broadcast(x, 10))); - ICHECK(planes.Eval().as()->value == 10); - ICHECK(broadcast(px * py, planes).Match(tir::Broadcast(x * 10, 10))); - ICHECK(broadcast(px, planes).Match(tir::Broadcast(x, scalable_lanes))); - ICHECK((vscale * PConst(4)).Match(planes.Eval())); + TVM_FFI_ICHECK(broadcast(px, planes).Match(tir::Broadcast(x, 10))); + TVM_FFI_ICHECK(planes.Eval().as()->value == 10); + TVM_FFI_ICHECK(broadcast(px * py, planes).Match(tir::Broadcast(x * 10, 10))); + TVM_FFI_ICHECK(broadcast(px, planes).Match(tir::Broadcast(x, scalable_lanes))); + TVM_FFI_ICHECK((vscale * PConst(4)).Match(planes.Eval())); } } @@ -135,14 +136,14 @@ TEST(Pattern, IntImm) { { // We can match integer and Var, both of which are // special case container of Expr - ICHECK((v * c).Match(tx * 3)); - ICHECK_EQ(c.Eval()->value, 3); - ICHECK((v * 3).Match(tx * 3)); + TVM_FFI_ICHECK((v * c).Match(tx * 3)); + TVM_FFI_ICHECK_EQ(c.Eval()->value, 3); + TVM_FFI_ICHECK((v * 3).Match(tx * 3)); } // cannot match c to ty - ICHECK(!(v * c).Match(tx * ty)); + TVM_FFI_ICHECK(!(v * c).Match(tx * ty)); // cannot match tx + 1 to v - ICHECK(!(v * c).Match((tx + 1) * 3)); + TVM_FFI_ICHECK(!(v * c).Match((tx + 1) * 3)); } TEST(Pattern, MatchWithType) { @@ -153,8 +154,8 @@ TEST(Pattern, MatchWithType) { tir::Var y("y", DataType::Float(32)); tir::Var x_int("x", DataType::Int(32)); tir::Var y_int("y", DataType::Int(32)); - ICHECK(pat.Match(x + y * 2.0f)); - ICHECK(!pat.Match(x_int + y_int * 2)); + TVM_FFI_ICHECK(pat.Match(x + y * 2.0f)); + TVM_FFI_ICHECK(!pat.Match(x_int + y_int * 2)); // match vectorized expr with specified element dtype arith::PVecDataType vec_ty(DataType::Float(32)); @@ -163,6 +164,6 @@ TEST(Pattern, MatchWithType) { tir::Var vy("y", DataType::Float(32, 8)); tir::Var vx_int("x", DataType::Int(32, 8)); tir::Var vy_int("y", DataType::Int(32, 8)); - ICHECK(vpat.Match(vx + vy * tir::Broadcast(2.0f, 8))); - ICHECK(!vpat.Match(vx_int + vy_int * tir::Broadcast(2, 8))); + TVM_FFI_ICHECK(vpat.Match(vx + vy * tir::Broadcast(2.0f, 8))); + TVM_FFI_ICHECK(!vpat.Match(vx_int + vy_int * tir::Broadcast(2, 8))); } diff --git a/tests/cpp/random_engine_test.cc b/tests/cpp/random_engine_test.cc index 078f99bd6e90..42d65aa4027b 100644 --- a/tests/cpp/random_engine_test.cc +++ b/tests/cpp/random_engine_test.cc @@ -33,7 +33,7 @@ TEST(RandomEngine, Randomness) { covered[rng() % 100] = true; } for (int i = 0; i < 100; i++) { - ICHECK(covered[i]); + TVM_FFI_ICHECK(covered[i]); } } @@ -45,7 +45,7 @@ TEST(RandomEngine, Reproducibility) { rng_b.Seed(0x23456789); for (int i = 0; i < 100000; i++) { - ICHECK_EQ(rng_a(), rng_b()); + TVM_FFI_ICHECK_EQ(rng_a(), rng_b()); } } @@ -56,10 +56,10 @@ TEST(RandomEngine, Serialization) { rng_a.Seed(0x56728); rand_state_b = rand_state_a; - for (int i = 0; i < 100000; i++) ICHECK_EQ(rng_a(), rng_b()); + for (int i = 0; i < 100000; i++) TVM_FFI_ICHECK_EQ(rng_a(), rng_b()); for (int i = 0; i < 123456; i++) rng_a(); rand_state_b = rand_state_a; - for (int i = 0; i < 100000; i++) ICHECK_EQ(rng_a(), rng_b()); + for (int i = 0; i < 100000; i++) TVM_FFI_ICHECK_EQ(rng_a(), rng_b()); } diff --git a/tests/cpp/runtime/logging_test.cc b/tests/cpp/runtime/logging_test.cc index e707606843bf..ab52db5fd008 100644 --- a/tests/cpp/runtime/logging_test.cc +++ b/tests/cpp/runtime/logging_test.cc @@ -72,14 +72,14 @@ TEST(TvmLogDebugSettings, VLogEnabledComplex) { TEST(TvmLogDebugSettings, IllFormed) { MATCH_THROW( - TvmLogDebugSettings::ParseSpec("foo/bar.cc=bogus;"), InternalError, + TvmLogDebugSettings::ParseSpec("foo/bar.cc=bogus;"), tvm::ffi::Error, ::testing::HasSubstr("TVM_LOG_DEBUG ill-formed at position 11: invalid level: \"bogus;\"")); - MATCH_THROW(TvmLogDebugSettings::ParseSpec("DEFAULT=2;bar/baz.cc=2"), InternalError, + MATCH_THROW(TvmLogDebugSettings::ParseSpec("DEFAULT=2;bar/baz.cc=2"), tvm::ffi::Error, ::testing::HasSubstr( "TVM_LOG_DEBUG ill-formed at position 8: invalid level: \"2;bar/baz.cc=2\"")); - MATCH_THROW(TvmLogDebugSettings::ParseSpec("DEFAULT=2,bar/baz.cc+2"), InternalError, + MATCH_THROW(TvmLogDebugSettings::ParseSpec("DEFAULT=2,bar/baz.cc+2"), tvm::ffi::Error, ::testing::HasSubstr("TVM_LOG_DEBUG ill-formed at position 22: expecting " "\"=\" after \"bar/baz.cc+2\"")); } diff --git a/tests/cpp/support/ring_buffer_test.cc b/tests/cpp/support/ring_buffer_test.cc index 9b78b2767731..43a921e0d1b1 100644 --- a/tests/cpp/support/ring_buffer_test.cc +++ b/tests/cpp/support/ring_buffer_test.cc @@ -50,15 +50,15 @@ TEST(RingBuffer, ReadWithCallback) { auto callback0 = [](const char* data, size_t size) -> size_t { const int* iptr = reinterpret_cast(data); - ICHECK_EQ(iptr[0], 1); - ICHECK_EQ(iptr[1], 2); + TVM_FFI_ICHECK_EQ(iptr[0], 1); + TVM_FFI_ICHECK_EQ(iptr[1], 2); return size; }; buffer.ReadWithCallback(callback0, 2 * sizeof(int)); auto callback1 = [](const char* data, size_t size) -> size_t { const int* iptr = reinterpret_cast(data); - ICHECK_EQ(iptr[0], 3); - ICHECK_EQ(iptr[1], 4); + TVM_FFI_ICHECK_EQ(iptr[0], 3); + TVM_FFI_ICHECK_EQ(iptr[1], 4); return size; }; buffer.ReadWithCallback(callback1, 2 * sizeof(int)); diff --git a/tests/cpp/support/scalars_test.cc b/tests/cpp/support/scalars_test.cc index 12a5145f2145..bf4974254f6c 100644 --- a/tests/cpp/support/scalars_test.cc +++ b/tests/cpp/support/scalars_test.cc @@ -29,32 +29,30 @@ namespace { // Here we just check handling which is difficult to test via the standard Python API. TEST(Scalars, IntImmToTensor_Unsupported) { - ASSERT_THROW(IntImmToTensor(IntImm(DataType::Int(15), 42)), runtime::InternalError); + ASSERT_THROW(IntImmToTensor(IntImm(DataType::Int(15), 42)), tvm::ffi::Error); } TEST(Scalars, FloatImmtoTensor_Unsupported) { - ASSERT_THROW(FloatImmToTensor(FloatImm(DataType::Float(15), 42.0)), runtime::InternalError); + ASSERT_THROW(FloatImmToTensor(FloatImm(DataType::Float(15), 42.0)), tvm::ffi::Error); } TEST(Scalars, TensorScalarToString_Unsupported) { auto ndarray = runtime::Tensor::Empty({}, DataType::Int(8), {DLDeviceType::kDLCPU, 0}); - ASSERT_THROW(TensorScalarToString(ndarray), runtime::InternalError); + ASSERT_THROW(TensorScalarToString(ndarray), tvm::ffi::Error); } TEST(Scalars, IntImmToString_Unsupported) { - ASSERT_THROW(IntImmToString(IntImm(DataType::Int(15), 42)), runtime::InternalError); + ASSERT_THROW(IntImmToString(IntImm(DataType::Int(15), 42)), tvm::ffi::Error); } TEST(Scalars, FloatImmToString_Unsupported) { - ASSERT_THROW(FloatImmToString(FloatImm(DataType::Float(15), 42.0)), runtime::InternalError); + ASSERT_THROW(FloatImmToString(FloatImm(DataType::Float(15), 42.0)), tvm::ffi::Error); } -TEST(Scalars, ValueToIntImm_Unsupported) { - ASSERT_THROW(ValueToIntImm(42, 15), runtime::InternalError); -} +TEST(Scalars, ValueToIntImm_Unsupported) { ASSERT_THROW(ValueToIntImm(42, 15), tvm::ffi::Error); } TEST(SCalars, ValueToFloatImm_Unsupported) { - ASSERT_THROW(ValueToFloatImm(42.0, 15), runtime::InternalError); + ASSERT_THROW(ValueToFloatImm(42.0, 15), tvm::ffi::Error); } } // namespace diff --git a/tests/cpp/target/canonicalizer/arm_aprofile_test.cc b/tests/cpp/target/canonicalizer/arm_aprofile_test.cc index a30a4edc2574..73c0e6c29110 100644 --- a/tests/cpp/target/canonicalizer/arm_aprofile_test.cc +++ b/tests/cpp/target/canonicalizer/arm_aprofile_test.cc @@ -363,12 +363,12 @@ TEST_F(AProfileCanonicalizerTest, UnexpectedTargetKind) { { try { Canonicalize({{"kind", ffi::String("c")}}); - } catch (const tvm::InternalError& e) { + } catch (const tvm::ffi::Error& e) { EXPECT_THAT(e.what(), HasSubstr("Expected target kind 'llvm', but got 'c'")); throw; } }, - tvm::InternalError); + tvm::ffi::Error); } TEST(AProfileCanonicalizerInvalid, LLVMUnsupportedArchitecture) { diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index fd1301c79442..df592f760866 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -69,7 +69,7 @@ TEST(TargetKind, GetAttrMap) { auto map = tvm::TargetKind::GetAttrMap("Attr1"); auto target_kind = tvm::TargetKind::Get("TestTargetKind").value(); std::string result = map[target_kind]; - ICHECK_EQ(result, "Value1"); + TVM_FFI_ICHECK_EQ(result, "Value1"); } TEST(TargetCreation, NestedConfig) { @@ -86,21 +86,21 @@ TEST(TargetCreation, NestedConfig) { }, }; Target target = Target(config); - ICHECK_EQ(target->kind, TargetKind::Get("TestTargetKind").value()); - ICHECK_EQ(target->tag, ""); - ICHECK(target->keys.empty()); + TVM_FFI_ICHECK_EQ(target->kind, TargetKind::Get("TestTargetKind").value()); + TVM_FFI_ICHECK_EQ(target->tag, ""); + TVM_FFI_ICHECK(target->keys.empty()); bool my_bool = target->GetAttr("my_bool").value(); - ICHECK_EQ(my_bool, true); + TVM_FFI_ICHECK_EQ(my_bool, true); ffi::Array your_names = target->GetAttr>("your_names").value(); - ICHECK_EQ(your_names.size(), 2U); - ICHECK_EQ(your_names[0], "junru"); - ICHECK_EQ(your_names[1], "jian"); + TVM_FFI_ICHECK_EQ(your_names.size(), 2U); + TVM_FFI_ICHECK_EQ(your_names[0], "junru"); + TVM_FFI_ICHECK_EQ(your_names[1], "jian"); ffi::Map her_maps = target->GetAttr>("her_maps").value(); - ICHECK_EQ(her_maps.size(), 2U); - ICHECK_EQ(her_maps["a"], 1); - ICHECK_EQ(her_maps["b"], 2); + TVM_FFI_ICHECK_EQ(her_maps.size(), 2U); + TVM_FFI_ICHECK_EQ(her_maps["a"], 1); + TVM_FFI_ICHECK_EQ(her_maps["b"], 2); } TEST(TargetCreationFail, UnrecognizedConfigOption) { @@ -485,7 +485,7 @@ TEST(TargetCreation, DetectSystemTriple) { }; Target target = Target(config); - ICHECK_EQ(target->kind, TargetKind::Get("llvm").value()); + TVM_FFI_ICHECK_EQ(target->kind, TargetKind::Get("llvm").value()); auto pf = tvm::ffi::Function::GetGlobal("target.llvm_get_system_triple"); if (!pf.has_value()) { @@ -505,23 +505,23 @@ TEST(TargetCreation, DeduplicateKeys) { {"device", ffi::String("arm_cpu")}, }; Target target = Target(config); - ICHECK_EQ(target->kind, TargetKind::Get("llvm").value()); - ICHECK_EQ(target->tag, ""); - ICHECK_EQ(target->keys.size(), 2U); - ICHECK_EQ(target->keys[0], "cpu"); - ICHECK_EQ(target->keys[1], "arm_cpu"); - ICHECK_EQ(target->attrs.size(), 2U); - ICHECK_EQ(target->GetAttr("device"), "arm_cpu"); + TVM_FFI_ICHECK_EQ(target->kind, TargetKind::Get("llvm").value()); + TVM_FFI_ICHECK_EQ(target->tag, ""); + TVM_FFI_ICHECK_EQ(target->keys.size(), 2U); + TVM_FFI_ICHECK_EQ(target->keys[0], "cpu"); + TVM_FFI_ICHECK_EQ(target->keys[1], "arm_cpu"); + TVM_FFI_ICHECK_EQ(target->attrs.size(), 2U); + TVM_FFI_ICHECK_EQ(target->GetAttr("device"), "arm_cpu"); } TEST(TargetKindRegistry, ListTargetKinds) { ffi::Array names = TargetKindRegEntry::ListTargetKinds(); - ICHECK_EQ(names.empty(), false); - ICHECK_EQ(std::count(std::begin(names), std::end(names), "llvm"), 1); + TVM_FFI_ICHECK_EQ(names.empty(), false); + TVM_FFI_ICHECK_EQ(std::count(std::begin(names), std::end(names), "llvm"), 1); } TEST(TargetKindRegistry, ListTargetOptions) { TargetKind llvm = TargetKind::Get("llvm").value(); ffi::Map attrs = TargetKindRegEntry::ListTargetKindOptions(llvm); - ICHECK_EQ(attrs.empty(), false); + TVM_FFI_ICHECK_EQ(attrs.empty(), false); } diff --git a/tests/cpp/tir_analysis_side_effect.cc b/tests/cpp/tir_analysis_side_effect.cc index 12c011fd6abb..7ac19c28f198 100644 --- a/tests/cpp/tir_analysis_side_effect.cc +++ b/tests/cpp/tir_analysis_side_effect.cc @@ -27,8 +27,9 @@ TEST(SimplePasses, SideEffect) { using namespace tvm; auto buf = tir::decl_buffer({16}, DataType::Float(32)); auto i = tir::Var("i", DataType::Int(32)); - ICHECK(tir::SideEffect(tir::BufferLoad(buf, {i})) == tir::CallEffectKind::kReadState); - ICHECK(tir::SideEffect(exp(tir::Cast(DataType::Float(32), i + 1))) == tir::CallEffectKind::kPure); - ICHECK(tir::SideEffect(tir::Call(DataType::Handle(), tir::builtin::tvm_storage_sync(), {})) == - tir::CallEffectKind::kUpdateState); + TVM_FFI_ICHECK(tir::SideEffect(tir::BufferLoad(buf, {i})) == tir::CallEffectKind::kReadState); + TVM_FFI_ICHECK(tir::SideEffect(exp(tir::Cast(DataType::Float(32), i + 1))) == + tir::CallEffectKind::kPure); + TVM_FFI_ICHECK(tir::SideEffect(tir::Call(DataType::Handle(), tir::builtin::tvm_storage_sync(), + {})) == tir::CallEffectKind::kUpdateState); } diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc index 6ae6deb50d2e..f8f453bdbc80 100644 --- a/tests/cpp/tir_scalable_datatype.cc +++ b/tests/cpp/tir_scalable_datatype.cc @@ -112,13 +112,13 @@ TEST(ScalableDataType, TestGetScalableVectorBytes) { { try { tvm::runtime::GetVectorBytes(scalable_type); - } catch (const tvm::InternalError& e) { + } catch (const tvm::ffi::Error& e) { EXPECT_THAT(e.what(), HasSubstr("Can't fetch the lanes of a scalable vector at a compile time")); throw; } }, - tvm::InternalError); + tvm::ffi::Error); } TEST(ScalableDataType, TestScalableDataTypeInvalidLanesError) { @@ -126,12 +126,12 @@ TEST(ScalableDataType, TestScalableDataTypeInvalidLanesError) { { try { tvm::DataType(kDLFloat, 62, 1, true); - } catch (const tvm::InternalError& e) { + } catch (const tvm::ffi::Error& e) { EXPECT_THAT(e.what(), HasSubstr("Invalid value for vscale factor")); throw; } }, - tvm::InternalError); + tvm::ffi::Error); } TEST(ScalableDataType, TestScalableDataTypeInvalidVscaleFactorAccess) { @@ -142,12 +142,12 @@ TEST(ScalableDataType, TestScalableDataTypeInvalidVscaleFactorAccess) { { try { fixed_length_type.vscale_factor(); - } catch (const tvm::InternalError& e) { + } catch (const tvm::ffi::Error& e) { EXPECT_THAT(e.what(), HasSubstr("A fixed length vector doesn't have a vscale factor")); throw; } }, - tvm::InternalError); + tvm::ffi::Error); } TEST(ScalableDataType, TestScalableDataTypeInvalidLanesAccess) { @@ -156,13 +156,13 @@ TEST(ScalableDataType, TestScalableDataTypeInvalidLanesAccess) { { try { scalable_type.lanes(); - } catch (const tvm::InternalError& e) { + } catch (const tvm::ffi::Error& e) { EXPECT_THAT(e.what(), HasSubstr("Can't fetch the lanes of a scalable vector at a compile time")); throw; } }, - tvm::InternalError); + tvm::ffi::Error); } TEST(ScalableDataType, TestScalableBool) { diff --git a/tests/python/relax/frontend_nn_extern_module.cc b/tests/python/relax/frontend_nn_extern_module.cc index 1bac39b35091..0f2f2dda8295 100644 --- a/tests/python/relax/frontend_nn_extern_module.cc +++ b/tests/python/relax/frontend_nn_extern_module.cc @@ -28,12 +28,12 @@ namespace { int _scalar_add(DLTensor* a, DLTensor* b, DLTensor* c) { using namespace tvm::runtime; - ICHECK(a->ndim == 0); - ICHECK(b->ndim == 0); - ICHECK(c->ndim == 0); - ICHECK(DataType(a->dtype) == DataType::Float(32)); - ICHECK(DataType(b->dtype) == DataType::Float(32)); - ICHECK(DataType(c->dtype) == DataType::Float(32)); + TVM_FFI_ICHECK(a->ndim == 0); + TVM_FFI_ICHECK(b->ndim == 0); + TVM_FFI_ICHECK(c->ndim == 0); + TVM_FFI_ICHECK(DataType(a->dtype) == DataType::Float(32)); + TVM_FFI_ICHECK(DataType(b->dtype) == DataType::Float(32)); + TVM_FFI_ICHECK(DataType(c->dtype) == DataType::Float(32)); float* a_data = static_cast(a->data); float* b_data = static_cast(b->data); float* c_data = static_cast(c->data); @@ -43,25 +43,25 @@ int _scalar_add(DLTensor* a, DLTensor* b, DLTensor* c) { int _test_sym(DLTensor* a, DLTensor* b, DLTensor* c) { using namespace tvm::runtime; - ICHECK(a->ndim == 3); // [x, y, 1] - ICHECK(b->ndim == 3); // [y, z, 5] - ICHECK(c->ndim == 4); // [x, y, z, 9] - ICHECK(DataType(a->dtype) == DataType::Float(32)); - ICHECK(DataType(b->dtype) == DataType::Float(32)); - ICHECK(DataType(c->dtype) == DataType::Float(32)); + TVM_FFI_ICHECK(a->ndim == 3); // [x, y, 1] + TVM_FFI_ICHECK(b->ndim == 3); // [y, z, 5] + TVM_FFI_ICHECK(c->ndim == 4); // [x, y, z, 9] + TVM_FFI_ICHECK(DataType(a->dtype) == DataType::Float(32)); + TVM_FFI_ICHECK(DataType(b->dtype) == DataType::Float(32)); + TVM_FFI_ICHECK(DataType(c->dtype) == DataType::Float(32)); int x = a->shape[0]; int y = a->shape[1]; int z = b->shape[1]; - ICHECK(a->shape[0] == x); - ICHECK(a->shape[1] == y); - ICHECK(a->shape[2] == 1); - ICHECK(b->shape[0] == y); - ICHECK(b->shape[1] == z); - ICHECK(b->shape[2] == 5); - ICHECK(c->shape[0] == x); - ICHECK(c->shape[1] == y); - ICHECK(c->shape[2] == z); - ICHECK(c->shape[3] == 9); + TVM_FFI_ICHECK(a->shape[0] == x); + TVM_FFI_ICHECK(a->shape[1] == y); + TVM_FFI_ICHECK(a->shape[2] == 1); + TVM_FFI_ICHECK(b->shape[0] == y); + TVM_FFI_ICHECK(b->shape[1] == z); + TVM_FFI_ICHECK(b->shape[2] == 5); + TVM_FFI_ICHECK(c->shape[0] == x); + TVM_FFI_ICHECK(c->shape[1] == y); + TVM_FFI_ICHECK(c->shape[2] == z); + TVM_FFI_ICHECK(c->shape[3] == 9); return 0; } } // namespace diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc index c502dbff2b72..6b559f628e93 100644 --- a/web/emcc/tvmjs_support.cc +++ b/web/emcc/tvmjs_support.cc @@ -156,7 +156,7 @@ class AsyncLocalSession : public LocalSession { }); } else { // for exception, we can pass through as since this is just normal encoding. - ICHECK_EQ(code, static_cast(RPCCode::kException)); + TVM_FFI_ICHECK_EQ(code, static_cast(RPCCode::kException)); callback(RPCCode::kException, args); } }); @@ -178,7 +178,7 @@ class AsyncLocalSession : public LocalSession { rv = retfunc; this->EncodeReturn(std::move(rv), [&](ffi::PackedArgs encoded_args) { const void* pf = encoded_args[0].as(); - ICHECK(pf != nullptr); + TVM_FFI_ICHECK(pf != nullptr); // mark as async. async_func_set_.insert(const_cast(pf)); callback(RPCCode::kReturn, encoded_args); @@ -233,11 +233,11 @@ class AsyncLocalSession : public LocalSession { packed_args[0] = nullptr; on_complete(RPCCode::kReturn, ffi::PackedArgs(packed_args, 1)); } else { - CHECK(dev.device_type == static_cast(kDLWebGPU)); + TVM_FFI_ICHECK(dev.device_type == static_cast(kDLWebGPU)); if (!async_wait_.has_value()) { async_wait_ = tvm::ffi::Function::GetGlobal("__async.wasm.WebGPUWaitForTasks"); } - CHECK(async_wait_.has_value()); + TVM_FFI_ICHECK(async_wait_.has_value()); ffi::Function packed_callback([on_complete](ffi::PackedArgs args, ffi::Any*) { int code = args[0].cast(); on_complete(static_cast(code), args.Slice(1)); @@ -270,7 +270,7 @@ class AsyncLocalSession : public LocalSession { repeats_to_cooldown); } else { auto pf = tvm::ffi::Function::GetGlobal(name); - CHECK(pf.has_value()) << "Cannot find " << name << " in the global function"; + TVM_FFI_ICHECK(pf.has_value()) << "Cannot find " << name << " in the global function"; return WrapWasmTimeEvaluator(*pf, dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown); @@ -295,7 +295,7 @@ class AsyncLocalSession : public LocalSession { } }; auto time_exec = tvm::ffi::Function::GetGlobal("__async.wasm.TimeExecution"); - CHECK(time_exec.has_value()) << "Cannot find wasm.GetTimer in the global function"; + TVM_FFI_ICHECK(time_exec.has_value()) << "Cannot find wasm.GetTimer in the global function"; (*time_exec)(ffi::TypedFunction(finvoke), dev, number, repeat, min_repeat_ms, limit_zero_time_iterations, cooldown_interval_ms, repeats_to_cooldown, /*cache_flush_bytes=*/0, on_complete); diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index a852c36a5a8e..5ab3e4928b96 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -126,18 +126,18 @@ TVM_FFI_STATIC_INIT_BLOCK() { void ArrayDecodeStorage(Tensor cpu_arr, TVMFFIByteArray* bytes, const std::string& format, const std::string& dtype) { - ICHECK_NE(bytes, nullptr); + TVM_FFI_ICHECK_NE(bytes, nullptr); const char* byte_data = bytes->data; const size_t byte_size = bytes->size; if (format == "f32-to-bf16" && dtype == "float32") { const uint16_t* bf16 = reinterpret_cast(byte_data); uint32_t* data = static_cast(cpu_arr->data); - ICHECK(cpu_arr.IsContiguous()); + TVM_FFI_ICHECK(cpu_arr.IsContiguous()); size_t size = 1; for (int i = 0; i < cpu_arr->ndim; ++i) { size *= cpu_arr->shape[i]; } - ICHECK_EQ(size, byte_size / 2); + TVM_FFI_ICHECK_EQ(size, byte_size / 2); for (size_t i = 0; i < size; ++i) { data[i] = static_cast(bf16[i]) << 16; } @@ -167,7 +167,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { for (int i = 0; i < args.size(); ++i) { // Get i-th TVMArray auto* arr_i = args[i].as(); - ICHECK(arr_i != nullptr); + TVM_FFI_ICHECK(arr_i != nullptr); for (size_t j = 0; j < arr_i->size(); ++j) { // Push back each j-th element of the i-th array data.push_back(arr_i->at(j)); @@ -184,8 +184,8 @@ Tensor ConcatEmbeddings(const std::vector& embeddings) { DLDevice device = embeddings[0]->device; int seqLen = 0; for (int i = 0; i < embeddings.size(); ++i) { - ICHECK_EQ(embeddings[i]->ndim, 2); - ICHECK_EQ(embeddings[i]->shape[1], hidden_size); + TVM_FFI_ICHECK_EQ(embeddings[i]->ndim, 2); + TVM_FFI_ICHECK_EQ(embeddings[i]->shape[1], hidden_size); seqLen += embeddings[i]->shape[0]; } diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index c8bfc4d81b7e..7471ad592e20 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -60,7 +60,7 @@ class WebGPUDeviceAPI : public DeviceAPI { public: WebGPUDeviceAPI() { auto fp = tvm::ffi::Function::GetGlobal("wasm.WebGPUDeviceAPI"); - CHECK(fp.has_value()) << "Cannot find wasm.WebGPUContext in the env"; + TVM_FFI_ICHECK(fp.has_value()) << "Cannot find wasm.WebGPUContext in the env"; auto getter = ffi::TypedFunction(*fp); alloc_space_ = getter("deviceAllocDataSpace"); free_space_ = getter("deviceFreeDataSpace"); @@ -89,7 +89,7 @@ class WebGPUDeviceAPI : public DeviceAPI { TVMStreamHandle stream) final { if (static_cast(dev_from.device_type) == kDLWebGPU && static_cast(dev_to.device_type) == kDLWebGPU) { - CHECK_EQ(dev_from.device_id, dev_to.device_id); + TVM_FFI_ICHECK_EQ(dev_from.device_id, dev_to.device_id); copy_within_gpu_(const_cast(from), from_offset, to, to_offset, size); } else if (static_cast(dev_from.device_type) == kDLWebGPU && dev_to.device_type == kDLCPU) { @@ -100,22 +100,26 @@ class WebGPUDeviceAPI : public DeviceAPI { void* from_ptr = static_cast(const_cast(from)) + from_offset; copy_to_gpu_(from_ptr, to, to_offset, size); } else { - LOG(FATAL) << "expect copy from/to WebGPU or between WebGPU"; + TVM_FFI_THROW(InternalError) << "expect copy from/to WebGPU or between WebGPU"; } } public: - TVMStreamHandle CreateStream(Device dev) final { LOG(FATAL) << "Not implemented"; } + TVMStreamHandle CreateStream(Device dev) final { + TVM_FFI_THROW(InternalError) << "Not implemented"; + } - void FreeStream(Device dev, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; } + void FreeStream(Device dev, TVMStreamHandle stream) final { + TVM_FFI_THROW(InternalError) << "Not implemented"; + } void SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) { - LOG(FATAL) << "Not implemented"; + TVM_FFI_THROW(InternalError) << "Not implemented"; } void StreamSync(Device dev, TVMStreamHandle stream) final { static auto func = tvm::ffi::Function::GetGlobal("__asyncify.WebGPUWaitForTasks"); - ICHECK(func.has_value()) << "Stream sync inside c++ only supported in asyncify mode"; + TVM_FFI_ICHECK(func.has_value()) << "Stream sync inside c++ only supported in asyncify mode"; (*func)(); } @@ -158,7 +162,7 @@ class WebGPUModuleNode final : public ffi::ModuleObj { ffi::Map fmap) : smap_(smap), fmap_(fmap) { auto fp = tvm::ffi::Function::GetGlobal("wasm.WebGPUCreateShader"); - CHECK(fp.has_value()); + TVM_FFI_ICHECK(fp.has_value()); create_shader_ = *fp; } @@ -179,7 +183,7 @@ class WebGPUModuleNode final : public ffi::ModuleObj { return ffi::Function([this](ffi::PackedArgs args, ffi::Any* rv) { auto name = args[0].cast(); auto it = smap_.find(name); - ICHECK(it != smap_.end()) << "Cannot find code " << name; + TVM_FFI_ICHECK(it != smap_.end()) << "Cannot find code " << name; *rv = it->second; }); } else if (name == "webgpu.update_prebuild") { @@ -198,7 +202,7 @@ class WebGPUModuleNode final : public ffi::ModuleObj { auto it = smap_.find(name); if (it != smap_.end()) { auto opt_info = fmap_.Get(name); - ICHECK(opt_info.has_value()); + TVM_FFI_ICHECK(opt_info.has_value()); FunctionInfo orig_info = opt_info.value(); FunctionInfo info(name, orig_info->arg_types, orig_info->launch_param_tags, orig_info->arg_extra_tags); @@ -212,7 +216,7 @@ class WebGPUModuleNode final : public ffi::ModuleObj { int GetPropertyMask() const final { return ffi::Module::kBinarySerializable; }; - ffi::Bytes SaveToBytes() const final { LOG(FATAL) << "Not implemented"; } + ffi::Bytes SaveToBytes() const final { TVM_FFI_THROW(InternalError) << "Not implemented"; } ffi::String InspectSource(const ffi::String& format) const final { // can only return source code. @@ -237,7 +241,7 @@ ffi::Module WebGPUModuleLoadFromBytes(const ffi::Bytes& bytes) { std::unordered_map smap; ffi::Map fmap; - ICHECK(stream.Read(&fmap)); + TVM_FFI_ICHECK(stream.Read(&fmap)); stream.Read(&smap); return ffi::Module(ffi::make_object(smap, fmap)); }