Also, calculating GPU costs is getting quite nuanced, with a wide range of prices (https://cloud-gpus.com/) and other variables that make it harder to do apples-to-apples comparisons.
I think you slightly misunderstood, and I wasn't clear enough—sorry! It's not a 30-70% speedup; it's 30-70% more cost-efficient. This is mainly due to non-NVIDIA chipsets (e.g., Google TPU) being cheaper, with some additional efficiency gains from JAX being more closely integrated with the XLA architecture.
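To illustrate the distinction with a quick sketch (the numbers here are made up for illustration, not our benchmark results): cost efficiency is about dollars per completed run, so a chip can be slower per step and still come out cheaper overall.

```python
# Hypothetical numbers only, to show "more cost-efficient" != "faster".
gpu_price, gpu_hours = 4.44, 10.0  # $/hr and wall-clock hours for one run
tpu_price, tpu_hours = 2.20, 11.0  # slower per run, but cheaper per hour

gpu_cost = gpu_price * gpu_hours   # total $ for the GPU run
tpu_cost = tpu_price * tpu_hours   # total $ for the TPU run

saving = 1 - tpu_cost / gpu_cost   # fraction saved per completed run
print(f"GPU: ${gpu_cost:.2f}, TPU: ${tpu_cost:.2f}, saving: {saving:.0%}")
```

With these made-up inputs the TPU run is 10% slower yet roughly 45% cheaper end to end, which is the sense in which "30-70% more cost-efficient" is meant.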
No, we haven't run our JAX + XLA stack on NVIDIA chipsets yet. I'm not sure how good NVIDIA's XLA backend support is.
At the bottom, it shows the calculations around the 30% cost efficiency of TPU vs GPU.
Our 30-70% range is based on numbers we collected from fine-tuning runs on TPU, compared against similar runs on NVIDIA GPUs (though those used other OSS libraries rather than our code).
It would be a lot more convincing if you actually ran it yourself and did a proper apples-to-apples comparison, especially considering that's the whole idea behind your project.
It's also comparing prices on Google Cloud, which has its own markup and is a lot more expensive than, say, RunPod. RunPod charges $1.64/hr for an A100 on Secure Cloud, while the A100 on Google is $4.44/hr. So in that context, a 30% price beat is actually a huge loss overall.
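Working that out with the prices quoted above (and assuming, for the sake of argument, that the 30% saving applies against the Google Cloud A100 rate):

```python
# Prices quoted above; the 30% figure is the claimed TPU cost advantage.
gcp_a100 = 4.44     # $/hr, A100 on Google Cloud
runpod_a100 = 1.64  # $/hr, A100 on RunPod Secure Cloud

# A 30% cost-efficiency win measured against the GCP price:
tpu_effective = gcp_a100 * (1 - 0.30)  # effective $/hr after the saving

# Compared to RunPod, that "win" is still far more expensive:
premium = tpu_effective / runpod_a100  # multiple of the RunPod price
print(f"TPU effective: ${tpu_effective:.2f}/hr, {premium:.1f}x RunPod")
```

The 30%-discounted rate comes out around $3.11/hr, still nearly twice RunPod's price, which is the point being made.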