-
Notifications
You must be signed in to change notification settings - Fork 5
Pull requests: jax-ml/jax-tpu-embedding
Author
Label
Projects
Milestones
Reviews
Assignee
Sort
Pull requests list
[JAX] Cache the aval_to_ir_type mapping during MLIR lowering.
#758
opened Apr 29, 2026 by
copybara-service
Bot
Loading…
Add support for CustomOptimizer in JAX TPU embedding.
#757
opened Apr 24, 2026 by
copybara-service
Bot
Loading…
This updates the rules_ml_toolchain dependency to commit 99c43dfe995a0e81c767d5b6d686191992672fe6.
#752
opened Apr 13, 2026 by
copybara-service
Bot
Loading…
Require shapes passed to ShapedArray to be canonical.
#745
opened Apr 6, 2026 by
copybara-service
Bot
Loading…
Flax Linen: Support learning rate schedules
#742
opened Apr 2, 2026 by
copybara-service
Bot
Loading…
[JAX SC] Deprecate
has_leading_dimension in embedding preprocessing functions.
#716
opened Feb 27, 2026 by
copybara-service
Bot
Loading…
[JAX SC] (2/n) Deprecate pmap/has_leading_dimension: Core Logic Tests
#715
opened Feb 27, 2026 by
copybara-service
Bot
Loading…
ProTip!
Type g p on any issue or pull request to go back to the pull request listing page.