[ARITH] Add optional Z3-backed proving to Analyzer#19666
Conversation
Add an optional Z3 SMT solver backend to tvm::arith::Analyzer for stronger integer arithmetic proving. The integration is guarded by a new USE_Z3 CMake option (default OFF). When enabled, Analyzer::CanProve runs the existing analysis path first and only falls back to Z3 when the existing analyzers cannot prove the predicate. When disabled, a stub implementation keeps the C++ and Python APIs available without Z3.
There was a problem hiding this comment.
Code Review
This pull request integrates Z3 SMT solver support into TVM's arithmetic analyzer, adding a new Z3Prover sub-analyzer, exposing its APIs to Python, and adding unit tests. The code review highlighted several critical issues in the Z3 prover implementation, including thread-safety risks when copying the thread-local Z3 context, potential compiler crashes from unsafe downcasting and LOG(FATAL) calls on unsupported nodes, incorrect handling of z3::unknown solver results, mathematical bugs due to mapping TVM's truncated division/modulo to Z3's floor division/modulo, and solver state corruption from permanently adding shift constraints. Minor improvements were also suggested to resolve an implicit double-to-unsigned conversion warning and fix a typo.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| solver = CreateSolver(*other_.ctx); | ||
| // 2. copy the context | ||
| // the context is a shared_ptr, we can just copy the pointer | ||
| ctx = other_.ctx; |
There was a problem hiding this comment.
Assigning to ctx (which is a thread_local static variable) inside CopyFrom is extremely dangerous. It overwrites the thread-local static context for the entire current thread with the source thread's context. If other_ was created on a different thread, this violates Z3's thread-safety rules (sharing a context across threads) and will cause data races or crashes. Even on the same thread, it is redundant. Since ctx is thread-local static, we should not reassign it in CopyFrom.
// Do not copy the thread-local static context pointer to avoid cross-thread safety issues.| } else { | ||
| auto min_val = Downcast<IntImm>(min_value(dtype))->value; | ||
| auto max_val = Downcast<IntImm>(max_value(dtype))->value; | ||
| solver.add(ctx->int_val(min_val) <= e && e <= ctx->int_val(max_val)); | ||
| } |
There was a problem hiding this comment.
Downcasting min_value(dtype) and max_value(dtype) to IntImm without checking if dtype is actually an integer or boolean type will cause a compiler crash when encountering variables or loads of other types (like float32 or handle). For non-integer/non-boolean types, min_value returns a FloatImm or other non-IntImm expression, causing Downcast<IntImm> to fail. We should guard this block to only apply to integer and boolean types.
} else if (dtype.is_int() || dtype.is_uint()) {\n auto min_val = Downcast<IntImm>(min_value(dtype))->value;\n auto max_val = Downcast<IntImm>(max_value(dtype))->value;\n solver.add(ctx->int_val(min_val) <= e && e <= ctx->int_val(max_val));\n }| z3::expr VisitExprDefault_(const Object* op) override { | ||
| LOG(FATAL) << "Z3Prover only support integers, but got " << op->GetTypeKey() << "."; | ||
| TVM_FFI_UNREACHABLE(); | ||
| } |
There was a problem hiding this comment.
Using LOG(FATAL) to crash the compiler when encountering an unsupported node type is unsafe for a fallback analyzer. During optimization passes, the analyzer can be called on arbitrary expressions. It should fail gracefully by returning a free variable (which cannot be proven true/false), allowing the analyzer to return false instead of crashing the entire process.
z3::expr VisitExprDefault_(const Object* op) override {\n return Create(static_cast<const PrimExprNode*>(op));\n }| auto result = solver.check(); | ||
| if (result != z3::sat) { | ||
| break; // No more solutions |
There was a problem hiding this comment.
If solver.check() returns z3::unknown (e.g., due to timeout or resource limit), the loop simply breaks and returns the partial count found so far. This is incorrect and misleading, as the caller might assume the count is complete. It should return a negative error code (like -1 or -3) to indicate incomplete/unknown results.
auto result = solver.check();\n if (result == z3::unknown) {\n solver.pop();\n solver.set(\"model\", false);\n return -1;\n } else if (result != z3::sat) {\n break; // No more solutions\n }| z3::expr VisitExpr_(const DivNode* op) override { | ||
| return VisitArith(z3::operator/, op, op->a, op->b); | ||
| } | ||
| z3::expr VisitExpr_(const ModNode* op) override { | ||
| return VisitArith(z3::operator%, op, op->a, op->b); | ||
| } |
There was a problem hiding this comment.
DivNode and ModNode represent truncated division and modulo in TVM, but they are currently translated directly to Z3's / and % operators, which represent floor division and modulo. This leads to incorrect proofs for negative numbers. They should be translated using a correct truncated division/modulo formula.
z3::expr VisitExpr_(const DivNode* op) override {\n if (IsValidDType(op->a->dtype) && IsValidDType(op->b->dtype)) {\n z3::expr a = VisitInt(op->a);\n z3::expr b = VisitInt(op->b);\n return z3::ite(a >= 0,\n z3::ite(b >= 0, a / b, -(a / -b)),\n z3::ite(b >= 0, -(-a / b), -a / -b));\n } else {\n return Create(op);\n }\n }\n z3::expr VisitExpr_(const ModNode* op) override {\n if (IsValidDType(op->a->dtype) && IsValidDType(op->b->dtype)) {\n z3::expr a = VisitInt(op->a);\n z3::expr b = VisitInt(op->b);\n z3::expr div_res = z3::ite(a >= 0,\n z3::ite(b >= 0, a / b, -(a / -b)),\n z3::ite(b >= 0, -(-a / b), -a / -b));\n return a - div_res * b;\n } else {\n return Create(op);\n }\n }| // For shift operations, we need to ensure the shift amount is non-negative | ||
| // and within reasonable bounds | ||
| z3::expr a_expr = VisitInt(a); | ||
| z3::expr b_expr = VisitInt(b); | ||
|
|
||
| // Add constraint that shift amount should be non-negative | ||
| // This is a common assumption in many programming languages | ||
| solver.add(b_expr >= 0); | ||
|
|
||
| // Also limit shift amount to avoid unrealistic large shifts | ||
| // We'll limit to 64 bits (reasonable for most use cases) | ||
| solver.add(b_expr < 64); | ||
|
|
||
| unsigned bit_width = std::max(a.dtype().bits(), b.dtype().bits()); | ||
| z3::expr a_bv = z3::int2bv(bit_width, a_expr); | ||
| z3::expr b_bv = z3::int2bv(bit_width, b_expr); | ||
|
|
||
| // Perform the shift in bit-vector domain, then cast back to int. | ||
| z3::expr result_bv = op_func(a_bv, b_bv); | ||
| return z3::bv2int(result_bv, true); | ||
| } else { | ||
| return Create(op); |
There was a problem hiding this comment.
Permanently adding b_expr >= 0 and b_expr < 64 to the global solver during expression translation (which happens inside CanProve without a push/pop) corrupts the solver state for all subsequent queries. Since Z3's bit-vector shift operations already handle out-of-bounds shift amounts correctly, these constraints are unnecessary and should be removed.
if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) {\n z3::expr a_expr = VisitInt(a);\n z3::expr b_expr = VisitInt(b);\n\n unsigned bit_width = std::max(a.dtype().bits(), b.dtype().bits());\n z3::expr a_bv = z3::int2bv(bit_width, a_expr);\n z3::expr b_bv = z3::int2bv(bit_width, b_expr);\n\n // Perform the shift in bit-vector domain, then cast back to int.\n z3::expr result_bv = op_func(a_bv, b_bv);\n return z3::bv2int(result_bv, true);\n } else {| // Z3's implementation of timeout, when setting timeout T ms, it will stop at T - 1 ms | ||
| // SetTimeoutMs(5); | ||
| // use rlimit, not timeout to ensure determinstic behavior | ||
| SetRLimit(1e4); |
| /// @brief Check trivil bad cases, return true if the expr is a bad case | ||
| /// Z3 prover may take a long time to initialize (at least 200us), | ||
| /// This optimization can speedup 30% of the test cases in our unit tests | ||
| bool CheckTrivilBadCases(const PrimExpr& expr) { |
Summary
This PR adds an optional Z3 SMT solver backend to
tvm::arith::Analyzerfor stronger integer arithmetic proving.The integration is disabled by default and guarded by a new
USE_Z3CMake option. When enabled,Analyzer::CanProvekeeps the existing TVM arithmetic analysis path first, then falls back to Z3 only when the existing analyzers cannot prove the predicate. When disabled, a stub implementation is built so the C++ and Python APIs remain available without requiring Z3.Features
This PR adds:
USE_Z3, defaulting toOFF.arith::Z3Proversub-analyzer owned byarith::Analyzer.Analyzer::CanProveflow: it first runs the existing TVM logic (simplify, const-int-bound, int-set, transitive comparison) unchanged; only if that fails does it askz3_prover.CanProve(simplified), which proves the predicate by checking that its negation is unsatisfiable under the current constraints. Z3 is a pure fallback and a negative/unknown result keeps the originalfalse.Analyzer.get_smtlib2(expr=None)Analyzer.set_z3_timeout_ms(timeout_ms)Analyzer.set_z3_rlimit(rlimit)Analyzer.get_z3_stats()Z3Prover::CanProveZ3Prover::BindZ3Prover::EnterConstraintZ3Prover::GetSMTLIB2Z3Prover::GetStatsZ3Prover::SetTimeoutMsZ3Prover::SetRLimitZ3Prover::GetModelZ3Prover::CountSatisfyingValuesmin,max,select,let, casts, floor division, floor modulo, and selected bitwise/shift operations.rlimit, withrandom_seedfixed to42.TileLang References
The implementation is based on the Z3 analyzer integration used in TileLang's TVM fork, with the upstream port kept optional and scoped to TVM's arithmetic analyzer.
Differences From TileLang
The core
Z3Proveralgorithm is kept close to the latest TileLang TVM implementation. The differences are in how it is integrated into upstream TVM:Optional build. We make Z3 optional, defaulting
USE_Z3=OFFand building az3_prover_off.ccstub when disabled, so Z3 stays out of the default build. In TileLang it is enabled by default.Z3 discovery. We add
cmake/modules/contrib/Z3.cmake, which prefers a system Z3 and falls back to the PyPIz3-solverpackage. TileLang resolves Z3 fromz3-solverdirectly.is_assumepropagation. We forwardis_assumefromAnalyzer::EnterConstraintinto the Z3 prover so it handles assumption scopes the same way as the rewrite simplifier. The previous implementation in TileLang does not forward this flag.