The difference is really the level at which you are calling functions on the GPU. Say you have a function `f(x,y,z) = x .+ y .* sin.(z)`. If you're on CUDA (simplifying here: the paper does this for Intel oneAPI, Metal, IPUs, and AMD GPUs simultaneously, but it's basically the same), then at some point you need to call a kernel function, a .ptx function compiled for CUDA, which is then applied over all of the inputs. One way to parallelize this is to have a set of primitive functions, `x .+ y`, `x .* y`, `sin.(x)`, and decompose the execution into those kernels: first call sin, then call multiply, then call plus. The other way is to build, on demand, a specialized .ptx kernel for the function `f` and call that once. Machine learning libraries take the former approach, but we demonstrate that the latter is much better in this scenario, because each kernel call carries a non-trivial overhead and that overhead ends up slowing down the whole process. If there's a tl;dr for the paper, it's this, plus scaling the approach to all GPU architectures from one codebase.
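The two strategies can be sketched in plain Python. Nothing here runs on a GPU: the hypothetical `launch` helper simply counts calls, and each call stands in for one kernel launch. The primitive version makes three launches and materializes two temporary arrays; the specialized version makes a single launch for all of `f`.

```python
import math

# Hypothetical launch counter: each launch() call stands in for one GPU
# kernel launch. This only models the call structure, not GPU performance.
launches = 0

def launch(kernel, *arrays):
    """Simulate one kernel launch: apply `kernel` elementwise in a single pass."""
    global launches
    launches += 1
    return [kernel(*xs) for xs in zip(*arrays)]

# Approach 1: a small set of primitive kernels, composed at the array level.
def f_primitives(x, y, z):
    s = launch(math.sin, z)                  # sin.(z)  -> temporary array
    p = launch(lambda a, b: a * b, y, s)     # y .* s   -> temporary array
    return launch(lambda a, b: a + b, x, p)  # x .+ p

# Approach 2: one kernel specialized to the whole function f.
def f_fused(x, y, z):
    return launch(lambda a, b, c: a + b * math.sin(c), x, y, z)

x, y, z = [1.0, 2.0, 3.0], [0.5, 0.5, 0.5], [0.0, 1.0, 2.0]

launches = 0
out1 = f_primitives(x, y, z)
unfused_launches = launches

launches = 0
out2 = f_fused(x, y, z)
fused_launches = launches

print(unfused_launches, fused_launches)  # 3 1
```

Both versions compute the same values; only the number of launches (and of intermediate arrays) differs.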
Now I'll simultaneously say that the choice machine learning libraries are making here is not stupid. You may look at this example and go "no duh, call 1 kernel instead of 3", but you never want to over-optimize. For the domain that ML libraries are designed for, these kernel calls are typically things like large matrix multiplications (that's the core of any deep neural network, with a few things around it). These kinds of operations are O(n^3) or O(n^2) on very large arrays. With that much compute per element of memory, the call overhead becomes a nearly negligible fraction of the total. Thus, for the use case targeted by ML libraries, approaching the design of the GPU library as "just make enough optimized kernels" works well. For example, it was counted in 2021 that PyTorch had about 2,000 such kernels (https://dev-discuss.pytorch.org/t/where-do-the-2000-pytorch-...). Sit down, optimize the CUDA kernels, then make the high-level code call the most appropriate one. That's a good design if the kernels are expensive enough, as they are in deep learning.
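The amortization argument can be made explicit with a back-of-the-envelope model (an illustrative assumption, not a measurement): suppose every kernel call pays a fixed overhead `t_0` and each unit of work costs `c`. For a call doing `W(n)` units of work, the fraction of time lost to overhead is

```latex
\phi(n) = \frac{t_0}{t_0 + c\,W(n)}, \qquad
W(n) =
\begin{cases}
  n^3 & \text{dense } n \times n \text{ matrix multiplication} \\
  n   & \text{elementwise operation on } n \text{ entries}
\end{cases}
```

With `t_0` and `c` held fixed, the matrix-multiplication case drives this fraction toward zero far faster in `n` than the elementwise case does, which is why a library built from pre-optimized kernels suits deep learning but not workloads dominated by O(n) operations.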
While JAX has a few other things going on, the vmap parallelism approaches of both PyTorch and JAX are effectively high-level tools for shoving larger arrays more nicely into such existing kernels. For example, one optimization that vmap does is to fuse matrix-vector multiplications into a matrix multiplication, i.e. computing Av1 and Av2 separately becomes the single product A*[v1 v2]. The purpose is still to use a small set of primitives and shove array operations that are as big as possible into them.
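The batching idea can be illustrated in plain Python. This is not how vmap is implemented; the helpers `matvec` and `matmul` are hypothetical stand-ins where each call represents one kernel. Two matrix-vector calls are replaced by one matrix-matrix call on the column-stacked inputs `[v1 v2]`, and the columns of the result are exactly `A*v1` and `A*v2`.

```python
def matvec(A, v):
    """One matrix-vector product A*v (stands in for one kernel call)."""
    return [sum(a * b for a, b in zip(row, v)) for row in A]

def matmul(A, B):
    """One matrix-matrix product A*B, with both matrices given as lists of rows."""
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]

A = [[1.0, 2.0],
     [3.0, 4.0]]
v1 = [1.0, 2.0]
v2 = [3.0, -1.0]

# Unbatched: two separate matrix-vector calls.
separate = [matvec(A, v1), matvec(A, v2)]

# Batched: stack v1 and v2 as the columns of V = [v1 v2], then make one call.
V = [list(r) for r in zip(v1, v2)]     # rows of the 2 x 2 matrix [v1 v2]
AV = matmul(A, V)
batched = [list(c) for c in zip(*AV)]  # column j of A*V equals A*v_j

print(separate == batched)  # True
```

The arithmetic is unchanged; what changes is that the work arrives at a single, larger primitive, which is the shape of call that pre-optimized kernels handle best.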
However, that is not a good idea in all domains. In ODE solvers, you have lots of control flow and O(n) operations. This can make that "negligible" overhead far from negligible, so one needs to design the parallelism very differently to avoid the performance issues one would hit with the "small-kernel, array-based approach". The better approach in this domain (as demonstrated in the paper) is to build completely new kernels for the functions you're trying to compute, i.e. generate CUDA code and a .ptx kernel for `f` directly, compile that, and do the one call. This has some downsides, of course: such a kernel effectively cannot be reused for other things, which means you need to be able to do this kernel generation automatically for it to be useful at a package level.
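A toy ensemble solve shows how the launch count scales in the two designs. Everything here is a hypothetical setup, not the paper's code or API: an ensemble of scalar ODEs `u' = -k*u`, `u(0) = 1`, one trajectory per `k`, integrated with fixed-step explicit Euler, where each call to the hypothetical `launch` helper stands in for one kernel launch. The array-based version pays two launches per time step; the per-trajectory version makes one launch in which each "thread" runs its whole time loop.

```python
launches = 0

def launch(kernel, *arrays):
    """Simulate one kernel launch: apply `kernel` elementwise in one pass."""
    global launches
    launches += 1
    return [kernel(*xs) for xs in zip(*arrays)]

ks = [0.5, 1.0, 2.0, 4.0]  # one trajectory per decay rate k
dt, steps = 0.01, 100      # fixed-step explicit Euler

# Array-based: the time loop runs on the host, and every step launches
# one primitive kernel per operation over the whole ensemble.
def solve_array_based(ks):
    u = [1.0] * len(ks)
    for _ in range(steps):
        du = launch(lambda k, ui: -k * ui, ks, u)     # du = -k .* u
        u = launch(lambda ui, d: ui + dt * d, u, du)  # u = u .+ dt .* du
    return u

# Per-trajectory kernel: the time loop (and any control flow one might add,
# such as adaptive step-size logic) runs inside a single launch.
def solve_one(k):
    u = 1.0
    for _ in range(steps):
        u = u + dt * (-k * u)
    return u

def solve_kernel_based(ks):
    return launch(solve_one, ks)

launches = 0
u_array = solve_array_based(ks)
array_launches = launches

launches = 0
u_kernel = solve_kernel_based(ks)
kernel_launches = launches

print(array_launches, kernel_launches)  # 200 1
```

Both designs produce the same trajectories, but the array-based one makes `2 * steps` launches, each doing only O(n) work, which is exactly the regime where per-call overhead stops being negligible.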
In other words, domain-specific languages optimize for their respective domains of choice, but that may leave performance on the table for use cases outside of their directly targeted audience.