
>I also like that jax.jit forces you to write "functional" functions free of side effects or in-place array updates. It might feel weird at first (and not every algorithm is suited for this style) but ultimately it leads to clearer and faster code.

It's not weird. It's actually the most natural way of doing things for me. You just write down your math equations as JAX and you're done.
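To make "write down your math equations" concrete, here is a minimal sketch. The linear-model loss, the function name `mse_loss`, and the toy data are illustrative assumptions, not anything from the thread. The function is pure (its output depends only on its inputs), so `jax.jit` can compile it and `jax.grad` can differentiate it directly.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: mean squared error of a linear model,
# written as the equation L(w) = (1/n) * ||X w - y||^2.
# No mutation and no side effects, so jit and grad both apply cleanly.
@jax.jit
def mse_loss(w, X, y):
    residual = X @ w - y
    return jnp.mean(residual ** 2)

X = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.array([0.5, -0.25])

loss = mse_loss(w, X, y)               # compiled forward pass
grad = jax.grad(mse_loss)(w, X, y)     # gradient w.r.t. w, also traced through jit
```

The function body is essentially the formula itself. The vectorized, whole-tensor style is what makes it feel natural.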



> You just write down your math equations as JAX and you're done.

It's natural when your basic unit is a whole vector (tensor), manipulated by some linear algebra expression. It's less natural if your basic unit is an element of a vector.

If you're solving sudoku, for example, the obvious 'update' is in-place.
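For contrast, here is a sketch of what a single sudoku-style cell update looks like in JAX. The 9x9 integer board and the helper `place` are hypothetical. JAX arrays are immutable, so NumPy-style item assignment raises a `TypeError`, and the update has to be written functionally with `.at[...].set(...)`, which returns a new array.

```python
import jax
import jax.numpy as jnp

# Hypothetical 9x9 sudoku board, 0 = empty cell.
board = jnp.zeros((9, 9), dtype=jnp.int32)

# The "obvious" in-place update is rejected: JAX arrays are immutable.
mutation_rejected = False
try:
    board[0, 0] = 5
except TypeError:
    mutation_rejected = True

@jax.jit
def place(board, r, c, v):
    # Functional update: returns a new board with one cell changed;
    # the original `board` is left untouched.
    return board.at[r, c].set(v)

new_board = place(board, 0, 0, 5)
```

For an element-at-a-time algorithm, every step has to thread the whole board through like this, which is the awkwardness described above.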

In-place updates are also often the right answer for performance reasons, such as writing the output of a .map() operation directly to the destination tensor. JAX leans heavily on compile-time optimizations to turn the mathematically-nice code into computer-nice code, so the delta between eager JAX and compiled JAX is much larger than the delta between eager PyTorch and compiled PyTorch.
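A small sketch of that eager/compiled gap, using a hypothetical `fill_squares` function that writes one element at a time. Per the JAX documentation, `.at[i].set(...)` outside of `jit` produces a fresh copy of the array each time, while inside `jit` XLA can usually lower the chain of updates to in-place writes. `donate_argnums=0` additionally allows XLA to reuse the input buffer as the output. Whether the buffer is actually reused is a compiler and backend decision (some backends just emit a warning and copy), so treat this as an illustration rather than a guarantee.

```python
import jax
import jax.numpy as jnp

# Hypothetical element-at-a-time "map": out[i] = i * i.
def fill_squares(out):
    for i in range(out.shape[0]):
        # Eagerly, each .at[].set materializes a new array (a copy per step).
        out = out.at[i].set(i * i)
    return out

# Eager execution: 8 separate update ops, 8 intermediate arrays.
eager_result = fill_squares(jnp.zeros(8, dtype=jnp.int32))

# Compiled: the loop is unrolled at trace time, and XLA may perform the
# updates in place; donating arg 0 lets it write into the input buffer.
compiled = jax.jit(fill_squares, donate_argnums=0)
compiled_result = compiled(jnp.zeros(8, dtype=jnp.int32))  # fresh array; donated, don't reuse
```

Both paths compute the same values. The difference lies in how much copying happens, which is where compiled JAX pulls far ahead of eager JAX.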
