call_jax doesn't take jax config into hashing. #8963

Open
zpcore opened this issue Apr 11, 2025 · 0 comments

Labels: bug (Something isn't working), tracing (Lazy Tensor tracing)

Comments

zpcore (Collaborator) commented Apr 11, 2025

🐛 Bug

call_jax doesn't take jax config into hashing.

Detail

Changes to the JAX config (https://github.com/jax-ml/jax/blob/3864c4f335d1d236d5367264f3885dfce8721d9d/jax/_src/config.py#L254) are not reflected in the arguments passed to call_jax, so they do not contribute to the hash. However, the config is embedded in the generated HLO (e.g., data precision), so computations run under different JAX configs can potentially reuse the same HLO.
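For illustration, here is a minimal sketch of the failure mode in plain Python. It does not use the real call_jax or JAX config machinery: `CONFIG`, `fake_lower`, and `HloCache` are hypothetical stand-ins, and the `default_matmul_precision` key is only meant to mirror the data-precision example above. It shows a cache keyed on the function arguments alone returning a stale result after a config change, and one possible remedy of folding a snapshot of the config into the key.

```python
# Hypothetical stand-in for global JAX config state (not the real jax.config).
CONFIG = {"default_matmul_precision": "default"}


def fake_lower(fn_name, shapes):
    """Stand-in for lowering to HLO: the current config is baked into the result."""
    return f"HLO({fn_name}, {shapes}, precision={CONFIG['default_matmul_precision']})"


class HloCache:
    """Toy cache of lowered computations (hypothetical, for illustration only)."""

    def __init__(self, include_config):
        self.include_config = include_config
        self._entries = {}

    def _key(self, fn_name, shapes):
        key = (fn_name, shapes)
        if self.include_config:
            # Fold a sorted snapshot of the config into the cache key.
            key += (tuple(sorted(CONFIG.items())),)
        return key

    def get(self, fn_name, shapes):
        key = self._key(fn_name, shapes)
        if key not in self._entries:
            self._entries[key] = fake_lower(fn_name, shapes)
        return self._entries[key]


def run_with_config_change(cache):
    """Request the same computation before and after changing the config."""
    shapes = ((2, 2), (2, 2))
    CONFIG["default_matmul_precision"] = "default"
    before = cache.get("matmul", shapes)
    CONFIG["default_matmul_precision"] = "highest"
    after = cache.get("matmul", shapes)
    return before, after


# Keyed on arguments only: the second call reuses the stale entry.
stale_before, stale_after = run_with_config_change(HloCache(include_config=False))

# Keyed on arguments plus config: the second call lowers again.
fresh_before, fresh_after = run_with_config_change(HloCache(include_config=True))
```

With the arguments-only key, both calls return the entry lowered under the first config. With the config folded in, the second call produces a separate entry. Which config fields would need to be part of the real hash is an open question this sketch does not answer.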

@zpcore self-assigned this Apr 11, 2025
@ysiraichi added the bug, tracing, and torchxla2 labels Apr 16, 2025
@qihqi removed the torchxla2 label Apr 17, 2025