Support integer types in TIR expression operators#18390
Conversation
Summary of ChangesHello @jikechao, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a crucial enhancement to the Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request addresses an important issue by adding support for integer types in tvm.tir.exp, preventing an InternalError during LLVM code generation. The approach of casting integer inputs to float64 aligns well with the behavior of libraries like NumPy and PyTorch. I've suggested a small improvement to make the integer type detection more robust. Additionally, for consistency across the API, it would be beneficial to consider applying this same integer-to-float casting logic to other transcendental functions (e.g., log, sin, cos) that likely have similar requirements for floating-point inputs, perhaps in a follow-up pull request.
cbalint13
left a comment
There was a problem hiding this comment.
Idea for future,
Beside this casting to float (simplest) I think we could try using a LUT like implementation assuring that we truly emit integer ISA counter part from llvm (staying in the int domain), so we could also tackle on a good integer softmax() implementation, TVM being a compiler I would expect such capability.
Hi @cbalint13, thanks for your review and suggestion. While LUT is well-suited for softmax due to its normalized output and constrained input range, implementing |
Added tests for 'exp' function to verify output against NumPy.
This PR addresses the issue where
tvm.tir.expdoes not support integer types (e.g., int32, int64), causing an InternalError during LLVM code generation with the message.The issue arises because the
llvm.expintrinsic expects floating-point inputs, but no type conversion is performed for integer inputs.I opened this PR to solve it via type conversion. This change aligns the behavior of
tir.expwith libraries like PyTorch and NumPy, which implicitly convert integer inputs to floating-point types for their exponential functions.Fix #18381