[ARITH] Add optional Z3-backed proving to Analyzer#19666

Closed
Ubospica wants to merge 1 commit into apache:main from Ubospica:2026-06-03-name

@Ubospica Ubospica commented Jun 3, 2026

Summary

This PR adds an optional Z3 SMT solver backend to tvm::arith::Analyzer for stronger integer arithmetic proving.

The integration is disabled by default and guarded by a new USE_Z3 CMake option. When enabled, Analyzer::CanProve keeps the existing TVM arithmetic analysis path first, then falls back to Z3 only when the existing analyzers cannot prove the predicate. When disabled, a stub implementation is built so the C++ and Python APIs remain available without requiring Z3.

Features

This PR adds:

  • Optional Z3 build support through USE_Z3, defaulting to OFF.
  • A new arith::Z3Prover sub-analyzer owned by arith::Analyzer.
  • Updated Analyzer::CanProve flow: the existing TVM logic (simplify, const-int-bound, int-set, transitive comparison) runs first and is unchanged. Only if it fails does the analyzer call z3_prover.CanProve(simplified), which proves the predicate by checking that its negation is unsatisfiable under the current constraints. Z3 is a pure fallback: a negative or unknown result leaves the original false answer in place.
  • SMT-LIB2 export for debugging and external solver reproduction.
  • Python APIs:
    • Analyzer.get_smtlib2(expr=None)
    • Analyzer.set_z3_timeout_ms(timeout_ms)
    • Analyzer.set_z3_rlimit(rlimit)
    • Analyzer.get_z3_stats()
  • C++ APIs:
    • Z3Prover::CanProve
    • Z3Prover::Bind
    • Z3Prover::EnterConstraint
    • Z3Prover::GetSMTLIB2
    • Z3Prover::GetStats
    • Z3Prover::SetTimeoutMs
    • Z3Prover::SetRLimit
    • Z3Prover::GetModel
    • Z3Prover::CountSatisfyingValues
  • Scalar integer, unsigned integer, and boolean expression translation to Z3.
  • Support for arithmetic, comparisons, boolean operators, min, max, select, let, casts, floor division, floor modulo, and selected bitwise/shift operations.
  • Deterministic solver control using Z3 rlimit, with random_seed fixed to 42.
  • Thread-local Z3 context sharing to reduce initialization overhead while keeping thread safety.
  • A disabled-mode stub implementation that returns conservative results when Z3 is not built.
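
The fallback ordering and the "prove by refuting the negation" idea described in the list above can be sketched without Z3. The following is a toy illustration only: Predicate, ProveByRefutation, and CanProveSketch are hypothetical names, and a brute-force scan over a small bounded domain stands in for the SMT solver.

```cpp
#include <cassert>
#include <functional>

// Hypothetical stand-in for a predicate over a single bounded integer variable.
using Predicate = std::function<bool(int)>;

// Toy refutation "solver": the predicate is proven on [lo, hi] exactly when
// its negation has no witness in that range.
bool ProveByRefutation(const Predicate& pred, int lo, int hi) {
  assert(lo <= hi);
  for (int x = lo; x <= hi; ++x) {
    if (!pred(x)) return false;  // x satisfies the negation, so no proof
  }
  return true;  // the negation is unsatisfiable on the domain
}

// Fallback ordering: the cheap prover runs first, and the solver is consulted
// only when the cheap prover cannot prove the predicate.
bool CanProveSketch(const Predicate& pred, int lo, int hi,
                    const std::function<bool(const Predicate&)>& cheap_prover,
                    int* solver_calls) {
  assert(solver_calls != nullptr);
  if (cheap_prover(pred)) return true;
  ++*solver_calls;
  return ProveByRefutation(pred, lo, hi);
}
```

A failed refutation returns false, which matches the "pure fallback" behaviour: the solver can only upgrade an answer from false to true, never the reverse.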

TileLang References

The implementation is based on the Z3 analyzer integration used in TileLang's TVM fork, with the upstream port kept optional and scoped to TVM's arithmetic analyzer.

Differences From TileLang

The core Z3Prover algorithm is kept close to the latest TileLang TVM implementation. The differences are in how it is integrated into upstream TVM:

  1. Optional build. We make Z3 optional, defaulting USE_Z3=OFF and building a z3_prover_off.cc stub when disabled, so Z3 stays out of the default build. In TileLang it is enabled by default.

  2. Z3 discovery. We add cmake/modules/contrib/Z3.cmake, which prefers a system Z3 and falls back to the PyPI z3-solver package. TileLang resolves Z3 from z3-solver directly.

  3. is_assume propagation. We forward is_assume from Analyzer::EnterConstraint into the Z3 prover so it handles assumption scopes the same way as the rewrite simplifier. The previous implementation in TileLang does not forward this flag.

Add an optional Z3 SMT solver backend to tvm::arith::Analyzer for
stronger integer arithmetic proving. The integration is guarded by a
new USE_Z3 CMake option (default OFF). When enabled, Analyzer::CanProve
runs the existing analysis path first and only falls back to Z3 when the
existing analyzers cannot prove the predicate. When disabled, a stub
implementation keeps the C++ and Python APIs available without Z3.
@Ubospica Ubospica closed this Jun 3, 2026
@Ubospica Ubospica deleted the 2026-06-03-name branch June 3, 2026 20:49

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request integrates Z3 SMT solver support into TVM's arithmetic analyzer, adding a new Z3Prover sub-analyzer, exposing its APIs to Python, and adding unit tests. The review found several critical issues in the Z3 prover implementation: thread-safety risks when copying the thread-local Z3 context; potential compiler crashes from unsafe downcasting and from LOG(FATAL) calls on unsupported nodes; incorrect handling of z3::unknown solver results; incorrect proofs caused by mapping TVM's truncated division/modulo directly to Z3's integer division/modulo, which do not always truncate toward zero; and solver state corruption from permanently adding shift constraints. Minor suggestions cover an implicit double-to-unsigned conversion and a typo.

solver = CreateSolver(*other_.ctx);
// 2. copy the context
// the context is a shared_ptr, we can just copy the pointer
ctx = other_.ctx;

high

Assigning to ctx (which is a thread_local static variable) inside CopyFrom is extremely dangerous. It overwrites the thread-local static context for the entire current thread with the source thread's context. If other_ was created on a different thread, this violates Z3's thread-safety rules (sharing a context across threads) and will cause data races or crashes. Even on the same thread, it is redundant. Since ctx is thread-local static, we should not reassign it in CopyFrom.

    // Do not copy the thread-local static context pointer to avoid cross-thread safety issues.
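
A minimal sketch of the point made in this comment, assuming a layout that differs from the PR: FakeContext, ThreadContext, and ProverSketch are hypothetical stand-ins, and the context is held as a member initialised from a thread-local accessor rather than as the thread_local static itself. Copying transfers solver-level state only, so each prover keeps the context of the thread that created it.

```cpp
#include <cassert>
#include <memory>
#include <thread>

// Hypothetical stand-in for z3::context; only its identity matters here.
struct FakeContext {};

// One lazily created context per thread.
std::shared_ptr<FakeContext> ThreadContext() {
  thread_local std::shared_ptr<FakeContext> ctx = std::make_shared<FakeContext>();
  return ctx;
}

struct ProverSketch {
  std::shared_ptr<FakeContext> ctx = ThreadContext();
  int num_constraints = 0;

  // Copies solver-level state only. The context is deliberately not
  // reassigned, so it always belongs to the thread that owns this prover.
  void CopyFrom(const ProverSketch& other) {
    assert(ctx == ThreadContext() && "prover used outside its owning thread");
    num_constraints = other.num_constraints;
  }
};

// Returns true when a worker thread observes a context distinct from the
// caller's, which is the property that reassigning ctx would break.
bool WorkerSeesDifferentContext() {
  const FakeContext* mine = ThreadContext().get();
  bool differs = false;
  std::thread worker([&] { differs = (ThreadContext().get() != mine); });
  worker.join();
  return differs;
}
```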

Comment on lines +133 to +137
} else {
  auto min_val = Downcast<IntImm>(min_value(dtype))->value;
  auto max_val = Downcast<IntImm>(max_value(dtype))->value;
  solver.add(ctx->int_val(min_val) <= e && e <= ctx->int_val(max_val));
}

high

Downcasting min_value(dtype) and max_value(dtype) to IntImm without checking if dtype is actually an integer or boolean type will cause a compiler crash when encountering variables or loads of other types (like float32 or handle). For non-integer/non-boolean types, min_value returns a FloatImm or other non-IntImm expression, causing Downcast<IntImm> to fail. We should guard this block to only apply to integer and boolean types.

      } else if (dtype.is_int() || dtype.is_uint()) {
        auto min_val = Downcast<IntImm>(min_value(dtype))->value;
        auto max_val = Downcast<IntImm>(max_value(dtype))->value;
        solver.add(ctx->int_val(min_val) <= e && e <= ctx->int_val(max_val));
      }

Comment on lines +753 to +756
z3::expr VisitExprDefault_(const Object* op) override {
  LOG(FATAL) << "Z3Prover only support integers, but got " << op->GetTypeKey() << ".";
  TVM_FFI_UNREACHABLE();
}

high

Using LOG(FATAL) to crash the compiler when encountering an unsupported node type is unsafe for a fallback analyzer. During optimization passes, the analyzer can be called on arbitrary expressions. It should fail gracefully by returning a free variable (which cannot be proven true/false), allowing the analyzer to return false instead of crashing the entire process.

  z3::expr VisitExprDefault_(const Object* op) override {
    return Create(static_cast<const PrimExprNode*>(op));
  }

Comment on lines +419 to +421
auto result = solver.check();
if (result != z3::sat) {
break; // No more solutions

high

If solver.check() returns z3::unknown (e.g., due to timeout or resource limit), the loop simply breaks and returns the partial count found so far. This is incorrect and misleading, as the caller might assume the count is complete. It should return a negative error code (like -1 or -3) to indicate incomplete/unknown results.

      auto result = solver.check();
      if (result == z3::unknown) {
        solver.pop();
        solver.set("model", false);
        return -1;
      } else if (result != z3::sat) {
        break;  // No more solutions
      }
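
To illustrate why the three solver outcomes need separate handling, here is a small self-contained sketch. ScriptedSolver, CheckResult, CountModels, and the -1 sentinel are assumptions made for illustration; the scripted solver simply replays a fixed sequence of outcomes in place of real Z3 calls.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

enum class CheckResult { kSat, kUnsat, kUnknown };

// Hypothetical scripted solver: replays a fixed sequence of Check() outcomes.
struct ScriptedSolver {
  std::vector<CheckResult> script;
  std::size_t next = 0;

  CheckResult Check() {
    assert(next < script.size());
    return script[next++];
  }
};

// Sentinel meaning "the count is incomplete" (assumed value).
constexpr long long kCountUnknown = -1;

// Counts satisfying assignments up to max_count. An unknown result aborts
// with the sentinel instead of reporting a partial count as if it were final.
long long CountModels(ScriptedSolver& solver, long long max_count) {
  long long count = 0;
  while (count < max_count) {
    CheckResult result = solver.Check();
    if (result == CheckResult::kUnknown) return kCountUnknown;
    if (result == CheckResult::kUnsat) break;  // enumeration is complete
    ++count;  // a real implementation would also block the model just found
  }
  return count;
}
```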

Comment on lines +617 to +622
z3::expr VisitExpr_(const DivNode* op) override {
  return VisitArith(z3::operator/, op, op->a, op->b);
}
z3::expr VisitExpr_(const ModNode* op) override {
  return VisitArith(z3::operator%, op, op->a, op->b);
}

high

In TVM, DivNode and ModNode represent truncated division and modulo, but they are currently translated directly to Z3's / and % operators. On integer expressions those follow SMT-LIB semantics, where the remainder is always non-negative, so the results can differ from truncation when the dividend is negative. This leads to incorrect proofs for negative numbers. Both nodes should be translated with a formula that implements truncated division and modulo.

  z3::expr VisitExpr_(const DivNode* op) override {
    if (IsValidDType(op->a->dtype) && IsValidDType(op->b->dtype)) {
      z3::expr a = VisitInt(op->a);
      z3::expr b = VisitInt(op->b);
      return z3::ite(a >= 0,
                     z3::ite(b >= 0, a / b, -(a / -b)),
                     z3::ite(b >= 0, -(-a / b), -a / -b));
    } else {
      return Create(op);
    }
  }
  z3::expr VisitExpr_(const ModNode* op) override {
    if (IsValidDType(op->a->dtype) && IsValidDType(op->b->dtype)) {
      z3::expr a = VisitInt(op->a);
      z3::expr b = VisitInt(op->b);
      z3::expr div_res = z3::ite(a >= 0,
                                 z3::ite(b >= 0, a / b, -(a / -b)),
                                 z3::ite(b >= 0, -(-a / b), -a / -b));
      return a - div_res * b;
    } else {
      return Create(op);
    }
  }
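
The case split in the suggestion can be checked on plain integers. The sketch below assumes the SMT-LIB integer division rule (remainder r with 0 <= r < |b|) as the model of the solver's division; EuclidDiv, EuclidMod, TruncDivViaEuclid, and TruncModViaEuclid are hypothetical helper names. C++ built-in / and % truncate toward zero, so they serve as the reference for truncated semantics.

```cpp
#include <cassert>
#include <cstdint>

// Integer division with a non-negative remainder: a == b * q + r, 0 <= r < |b|.
int64_t EuclidDiv(int64_t a, int64_t b) {
  assert(b != 0);
  int64_t q = a / b;  // C++ truncates toward zero
  int64_t r = a % b;  // takes the sign of the dividend
  if (r < 0) q += (b > 0) ? -1 : 1;  // shift q so the remainder becomes >= 0
  return q;
}

int64_t EuclidMod(int64_t a, int64_t b) { return a - EuclidDiv(a, b) * b; }

// Same case split as the suggested ite(...) expression: every division is
// applied to non-negative operands, where both conventions agree, and the
// sign is restored afterwards.
int64_t TruncDivViaEuclid(int64_t a, int64_t b) {
  assert(b != 0);
  if (a >= 0) return b >= 0 ? EuclidDiv(a, b) : -EuclidDiv(a, -b);
  return b >= 0 ? -EuclidDiv(-a, b) : EuclidDiv(-a, -b);
}

int64_t TruncModViaEuclid(int64_t a, int64_t b) {
  return a - TruncDivViaEuclid(a, b) * b;
}
```

For example, with a = -7 and b = 2 the non-negative-remainder quotient is -4 while truncation gives -3, which is the kind of mismatch that can make a proof about negative values unsound.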

Comment on lines +728 to +749
// For shift operations, we need to ensure the shift amount is non-negative
// and within reasonable bounds
z3::expr a_expr = VisitInt(a);
z3::expr b_expr = VisitInt(b);

// Add constraint that shift amount should be non-negative
// This is a common assumption in many programming languages
solver.add(b_expr >= 0);

// Also limit shift amount to avoid unrealistic large shifts
// We'll limit to 64 bits (reasonable for most use cases)
solver.add(b_expr < 64);

unsigned bit_width = std::max(a.dtype().bits(), b.dtype().bits());
z3::expr a_bv = z3::int2bv(bit_width, a_expr);
z3::expr b_bv = z3::int2bv(bit_width, b_expr);

// Perform the shift in bit-vector domain, then cast back to int.
z3::expr result_bv = op_func(a_bv, b_bv);
return z3::bv2int(result_bv, true);
} else {
return Create(op);

high

Permanently adding b_expr >= 0 and b_expr < 64 to the global solver during expression translation (which happens inside CanProve without a push/pop) corrupts the solver state for all subsequent queries. Since Z3's bit-vector shift operations already handle out-of-bounds shift amounts correctly, these constraints are unnecessary and should be removed.

    if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) {
      z3::expr a_expr = VisitInt(a);
      z3::expr b_expr = VisitInt(b);

      unsigned bit_width = std::max(a.dtype().bits(), b.dtype().bits());
      z3::expr a_bv = z3::int2bv(bit_width, a_expr);
      z3::expr b_bv = z3::int2bv(bit_width, b_expr);

      // Perform the shift in bit-vector domain, then cast back to int.
      z3::expr result_bv = op_func(a_bv, b_bv);
      return z3::bv2int(result_bv, true);
    } else {
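
The state corruption described above can be pictured with a toy assertion stack. AssertionStack, QueryLeaky, and QueryScoped are hypothetical names, and the class only mimics the push/pop bookkeeping of an incremental solver. The review recommends removing the side constraints outright; the scoped variant is shown only to contrast with the leaking one.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical minimal model of an incremental solver's assertion stack.
class AssertionStack {
 public:
  void Add(const std::string& constraint) { assertions_.push_back(constraint); }
  void Push() { marks_.push_back(assertions_.size()); }
  void Pop() {
    assert(!marks_.empty());
    assertions_.resize(marks_.back());  // drop everything added since Push()
    marks_.pop_back();
  }
  std::size_t Size() const { return assertions_.size(); }

 private:
  std::vector<std::string> assertions_;
  std::vector<std::size_t> marks_;
};

// Leaky: side constraints added during translation outlive the query.
void QueryLeaky(AssertionStack& solver) {
  solver.Add("b >= 0");
  solver.Add("b < 64");
}

// Scoped: the same constraints are discarded once the query is finished.
void QueryScoped(AssertionStack& solver) {
  solver.Push();
  solver.Add("b >= 0");
  solver.Add("b < 64");
  // ... the satisfiability check would run here ...
  solver.Pop();
}
```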

// Z3's implementation of timeout, when setting timeout T ms, it will stop at T - 1 ms
// SetTimeoutMs(5);
// use rlimit, not timeout to ensure determinstic behavior
SetRLimit(1e4);

medium

Using a double literal 1e4 for SetRLimit causes an implicit double-to-unsigned conversion. It is better to use an integer literal like 10000 or 10000U to avoid any compiler warnings or conversion overhead.

    SetRLimit(10000U);

/// @brief Check trivil bad cases, return true if the expr is a bad case
/// Z3 prover may take a long time to initialize (at least 200us),
/// This optimization can speedup 30% of the test cases in our unit tests
bool CheckTrivilBadCases(const PrimExpr& expr) {

medium

Typo in function name: 'Trivil' should be 'Trivial'.

  bool CheckTrivialBadCases(const PrimExpr& expr) {
