
Published V0.2.6 and reached a milestone #65


Open
ASKabalan opened this issue Mar 1, 2025 · 0 comments

@ASKabalan
Collaborator

@EiffL

I updated jaxdecomp to v0.2.6.

This is the most complete version yet.
Everything is supported (vmap, jacfwd and jacrev) and exhaustively tested.
Errors that users might run into are well documented, and they are also caught and re-raised as RuntimeErrors with clear messages pointing users to my documentation.
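Because the library is described as a drop-in replacement for jnp.fft.fftn, the sketch below uses plain jnp.fft.fftn on a single device to show the kind of composition meant by "vmap, jacfwd and jacrev are supported". The assumption is that the distributed FFT composes with these transforms in the same way. The smooth function and its Gaussian kernel width are illustrative only and are not taken from jaxdecomp.

```python
import jax
import jax.numpy as jnp


def smooth(field):
    """Illustrative Gaussian smoothing in Fourier space (real field in, real field out)."""
    # Wavenumbers along each axis, in cycles per sample
    kx, ky, kz = jnp.meshgrid(*[jnp.fft.fftfreq(n) for n in field.shape], indexing="ij")
    k2 = kx**2 + ky**2 + kz**2
    kernel = jnp.exp(-0.5 * k2 / 0.1**2)  # hypothetical kernel width of 0.1
    # Forward FFT, multiply by the kernel, inverse FFT, keep the real part
    return jnp.real(jnp.fft.ifftn(jnp.fft.fftn(field) * kernel))


x = jax.random.normal(jax.random.PRNGKey(0), (4, 4, 4))

# Forward- and reverse-mode Jacobians of the FFT-based map
J_fwd = jax.jacfwd(smooth)(x)
J_rev = jax.jacrev(smooth)(x)
print(J_fwd.shape)  # (4, 4, 4, 4, 4, 4)

# vmap over a small batch of fields
batch = jnp.stack([x, 2 * x])
out = jax.vmap(smooth)(batch)
```

Since smooth is linear and returns a real array, both Jacobian modes apply (jacrev would reject a complex-valued output) and they should agree up to float32 rounding.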

Everything works directly on jax.numpy arrays; there is no need for any Pytree or static sharding annotation, so it is a drop-in replacement for jnp.fft.fftn.
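To make "no Pytree or static sharding annotation" concrete, here is a hedged sketch of the call site using only jax.sharding and jnp.fft.fftn: the array carries its own sharding, and the FFT is called on it directly. The jaxdecomp function that would replace jnp.fft.fftn is not named in this issue, so it is not shown; the sketch also assumes the device count divides 16.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D mesh over all available devices (a single CPU device also works)
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))

# Global 3D field, sharded along its first axis (slab layout)
field = jnp.arange(16 * 16 * 16, dtype=jnp.float32).reshape(16, 16, 16)
field = jax.device_put(field, NamedSharding(mesh, P("x", None, None)))

# Plain jnp.fft.fftn on the sharded array: no wrapper type, no sharding argument
k_field = jax.jit(jnp.fft.fftn)(field)
print(k_field.shape, k_field.dtype)  # (16, 16, 16) complex64
```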

It uses custom_partitioning rather than shard_map (I was able to make everything work with it).

The only thing missing is an equivalent of rfft. This is a bit challenging because the halved axis has length n/2 + 1, which for the usual power-of-two grid sizes is odd and therefore cannot be split evenly across devices, as in:

(16, 16, 16) => rfft => (16, 16, 9)
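The shape change follows from the Hermitian symmetry of a real-input FFT: the last transformed axis keeps only n // 2 + 1 coefficients. A minimal single-device check with jnp.fft.rfftn shows why the result is awkward to shard evenly; the device counts 2, 4 and 8 are only example values.

```python
import jax.numpy as jnp

field = jnp.zeros((16, 16, 16), dtype=jnp.float32)
k_field = jnp.fft.rfftn(field)
print(k_field.shape)  # (16, 16, 9)

# The halved axis has 16 // 2 + 1 = 9 entries; an odd length leaves a
# remainder for any even number of devices
remainders = {n_dev: k_field.shape[-1] % n_dev for n_dev in (2, 4, 8)}
print(remainders)  # {2: 1, 4: 1, 8: 1}
```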

I think we can still do it with the non-contiguous FFT implementation (which is properly implemented with both the cuDecomp and the JAX backends). For now, though, I am leaving this for later, as it is not crucial for our work with jaxpm.
