@EiffL
I updated jaxdecomp to v0.2.6.
This is the most complete version yet.
Everything is supported (vmap, jacfwd, and jacrev) and exhaustively tested.
The errors that users might run into are documented, and they are also caught and raised as RuntimeErrors with clear messages telling users to check my documentation.
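A minimal sketch of this kind of early, explicit failure. The `check_decomposition` helper, the `pdims` tuple (number of devices along each decomposed axis) and the message text are all hypothetical illustrations, not jaxdecomp's actual API:

```python
from typing import Sequence


def check_decomposition(global_shape: Sequence[int], pdims: Sequence[int]) -> None:
    """Hypothetical validator: raise RuntimeError if an axis cannot be split evenly.

    `pdims[i]` is the (assumed) number of devices along axis i; axes beyond
    len(pdims) are treated as undistributed and are not checked.
    """
    for axis, (size, parts) in enumerate(zip(global_shape, pdims)):
        if size % parts != 0:
            raise RuntimeError(
                f"Axis {axis} has size {size}, which is not divisible by "
                f"{parts} devices. Please check the documentation on "
                "supported decompositions."
            )


check_decomposition((16, 16, 16), (2, 4))  # evenly divisible: no error

try:
    check_decomposition((16, 18, 16), (2, 4))  # 18 % 4 != 0
except RuntimeError as err:
    print(err)
```

Failing before any communication is launched keeps the error message about the user's input rather than about a low-level collective.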
Everything works directly on jax.numpy arrays; there is no need for any Pytree or static sharding annotation, so it is a drop-in replacement for jnp.fft.fftn.
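A minimal sketch of the kind of transformation check a drop-in replacement should pass, using `jnp.fft.fftn` itself as the reference (the distributed function would be assigned to `fft3d` instead; its exact name is not given here, so it is left out). The loss is real-valued so that `jacfwd`/`jacrev` work without `holomorphic=True`, and by Parseval its gradient is `2 * N * x`:

```python
import jax
import jax.numpy as jnp

# Reference transform; a distributed drop-in replacement would be swapped in here.
fft3d = jnp.fft.fftn


def power(x):
    # Real-valued scalar: sum of |X_k|^2, written without abs() to stay smooth.
    X = fft3d(x)
    return jnp.sum(X.real ** 2 + X.imag ** 2)


x = jax.random.normal(jax.random.PRNGKey(0), (8, 8, 8))

# vmap: the batched transform should match a Python loop over the batch.
batch = jnp.stack([x, 2.0 * x])
batched = jax.vmap(fft3d)(batch)
looped = jnp.stack([fft3d(b) for b in batch])

# Forward and reverse mode should agree with each other and with Parseval:
# sum |X|^2 = N * sum x^2 for the unnormalized DFT, so the gradient is 2 * N * x.
g_fwd = jax.jacfwd(power)(x)
g_rev = jax.jacrev(power)(x)
n = x.size
```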
It uses custom_partitioning rather than shard_map (I was able to make everything work).

The only thing missing is an equivalent of rfft. This is a bit challenging because the last axis becomes n/2 + 1, so for power-of-two grids one of the axes will always have odd length, as in:
(16, 16, 16) => rfft => (16, 16, 9)
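A small sketch of the shape arithmetic, on a single device with plain `jax.numpy`. It also shows that `rfftn` factorizes into a real 1D `rfft` along the last axis followed by complex FFTs along the other axes, which is one way to think about where the odd-length axis appears. This illustrates the problem only; it is not a proposed distributed rfft:

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (16, 16, 16))

full = jnp.fft.rfftn(x)
print(full.shape)  # last axis becomes 16 // 2 + 1 = 9

# rfftn = real rfft along the last axis, then complex FFTs along axes 0 and 1.
staged = jnp.fft.fftn(jnp.fft.rfft(x, axis=-1), axes=(0, 1))

# The halved axis has length n // 2 + 1, which is odd whenever n is a
# multiple of 4 (e.g. any power of two >= 4), so it cannot be split evenly
# across an even number of devices.
halved = {n: n // 2 + 1 for n in (8, 16, 32, 64)}
print(halved)
```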
I think we can still do it with the non-contiguous FFT implementation (which is properly implemented using either cuDecomp or the JAX backend), but for now I am leaving this for later since it is not crucial for our work with jaxpm.
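For readers unfamiliar with how distributed FFTs are usually assembled, here is a sketch of the general "local FFTs plus a global transpose" pattern, emulated on one device with a slab decomposition along axis 0. The function name `slab_fft3d` and the slab layout are illustrative assumptions; this is not the cuDecomp or jaxdecomp algorithm, which works with pencil decompositions and real collectives:

```python
import jax
import jax.numpy as jnp


def slab_fft3d(x, n_devices):
    """Emulate a slab-decomposed 3D FFT; each list entry plays one "device"."""
    # Step 1: each device holds a slab of axis 0 and transforms axes 1 and 2 locally.
    slabs = [jnp.fft.fftn(s, axes=(1, 2)) for s in jnp.split(x, n_devices, axis=0)]
    # Step 2: all-to-all transpose. Device i sends chunk j of axis 1 to device j,
    # so afterwards device j owns all of axis 0 for its chunk of axis 1.
    pieces = [jnp.split(s, n_devices, axis=1) for s in slabs]
    regrouped = [
        jnp.concatenate([pieces[i][j] for i in range(n_devices)], axis=0)
        for j in range(n_devices)
    ]
    # Step 3: axis 0 is now local, so each device finishes with a 1D FFT along it.
    finished = [jnp.fft.fft(r, axis=0) for r in regrouped]
    # Reassemble the global result (distributed along axis 1) for comparison.
    return jnp.concatenate(finished, axis=1)


x = jax.random.normal(jax.random.PRNGKey(0), (16, 16, 16))
result = slab_fft3d(x, n_devices=4)
```

The transpose in step 2 is where every axis that gets split must divide evenly by the number of devices, which is why the odd-length rfft axis above is awkward.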