[Relax][NN] Use int64 for RoPE apply flag #19430
Code Review
This pull request updates the data type of the apply_rope parameter from int32 to int64 in the Python frontend and changes the corresponding cast to int64_t in the C++ runtime. These changes ensure consistency in integer precision across the codebase. I have no feedback to provide.
Downstream impact: This fixes MLC-LLM export paths that instantiate the RoPE PrimFunc through |
cchung100m left a comment
LGTM, thanks to @xthomaswang 😄
This patch aligns the dtype of the `apply_rope` flag used by `llama_rope_with_position_map` with the host-side value passed through Relax `call_tir`.
Previously the PrimFunc declared `apply_rope` as `T.int32`, while the caller-side scalar value is represented as an int64 Relax PrimValue / ShapeExpr value. This caused Relax well-formed analysis to reject the IR with:
The mismatch can be reproduced through downstream `nn.Module.export_tvm` paths such as MLC-LLM `convert_weight` / `compile`.

This change updates:

- `llama_rope_with_position_map`: `apply_rope`: `T.int32` -> `T.int64`
- `PagedKVCache`: pass the split-rotary flag as `int64_t`
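For readers unfamiliar with this class of failure, here is a minimal pure-Python sketch of the dtype mismatch described above. It is not TVM code and does not reproduce the actual well-formed check: `Param`, `ScalarArg`, and `check_call` are hypothetical names made up for illustration, and the int64 default for caller-side scalars is an assumption taken from the description in this PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Param:
    """A declared scalar parameter of a toy kernel signature (hypothetical)."""
    name: str
    dtype: str


@dataclass(frozen=True)
class ScalarArg:
    """A scalar value supplied by the caller, tagged with its dtype (hypothetical)."""
    value: int
    dtype: str = "int64"  # assumption: caller-side scalars are int64


def check_call(params, args):
    """Return a list of dtype mismatches between declared params and call args."""
    errors = []
    for param, arg in zip(params, args):
        if param.dtype != arg.dtype:
            errors.append(f"{param.name}: declared {param.dtype}, got {arg.dtype}")
    return errors


# Before the patch: the flag is declared as int32, the caller passes int64.
before = check_call([Param("apply_rope", "int32")], [ScalarArg(1)])

# After the patch: the declaration matches the caller-side dtype.
after = check_call([Param("apply_rope", "int64")], [ScalarArg(1)])

print(before)  # ['apply_rope: declared int32, got int64']
print(after)   # []
```

The sketch only shows why changing the declaration to match the caller removes the mismatch; the real check operates on Relax IR and reports its own diagnostics.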