you don't transpose it before the matmul; you always have it stored transposed (i.e., when you print the weights of a linear layer in pytorch, you're actually seeing (A^t)^t, and what's stored is A^t).
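a minimal sketch of what that looks like in pytorch (toy shapes, nothing model-specific):

```python
import torch
import torch.nn as nn

layer = nn.Linear(in_features=3, out_features=5)

# the stored weight already has the "transposed" shape:
# (out_features, in_features), i.e. A^t
print(layer.weight.shape)  # torch.Size([5, 3])

# forward computes x @ A^t + b directly, so nothing is
# transposed at matmul time
x = torch.randn(2, 3)
y = layer(x)
assert torch.allclose(y, x @ layer.weight.T + layer.bias)
```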
a polyhedral compiler wouldn't find this either - polyhedral compilation is about finding optimal schedules for loop nests, i.e., the order in which iterations that are independent (wrt dataflow) run. as far as i know, a transpose can't be expressed in the polyhedral model.
Hmm, I thought GCC's polyhedral optimizations included a loop transposition, but it turns out I was remembering an old "-floop-transpose" flag that seems to exist only in old Apple GCC, added to get a SPEC win…
since this post might attract people that are mathematica power users: how good are Mathematica's optimization routines vs commercial solvers like gurobi or cplex? reason i ask is i'm spinning up a project that'll require a good bit of MIP/MILP, and i have a mathematica license but i'm considering getting a gurobi license. since i've never tried gurobi i can't compare.
I never heard about Mathematica being used to solve large-scale MIPs. The best solvers in the market are very difficult to beat, since this is such a unique market with very specialized domain knowledge. So, while I guess it is possible to solve small to medium sized problems with Mathematica, I don't believe they're able to compete with some of the best solvers like Gurobi.
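for anyone weighing the two, here's what the gurobi side looks like via gurobipy - a hypothetical toy MILP just to show the API shape (requires a license to run):

```python
import gurobipy as gp
from gurobipy import GRB

m = gp.Model("toy_milp")
x = m.addVar(vtype=GRB.INTEGER, name="x")      # integer decision variable
y = m.addVar(vtype=GRB.CONTINUOUS, name="y")   # continuous decision variable

m.setObjective(3 * x + 2 * y, GRB.MAXIMIZE)
m.addConstr(x + y <= 4, name="capacity")
m.addConstr(x - y >= 1, name="spread")

m.optimize()
if m.Status == GRB.OPTIMAL:
    print(x.X, y.X, m.ObjVal)
```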
these two things have nothing to do with each other. jax doesn't compile numpy, it reimplements the api using `ufunc`. in general, every single numerical kernel is always mapped to some kind of compiled code.
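to illustrate the reimplementation point (standard numpy / jax.numpy usage; the exact array type name varies by jax version):

```python
import numpy as np
import jax.numpy as jnp

a = np.arange(6.0).reshape(2, 3)

# same API surface, different implementation: jnp dispatches to
# XLA-backed kernels and returns a jax Array, not an ndarray
print(type(np.sum(a)))   # <class 'numpy.float64'>
print(type(jnp.sum(a)))  # a jax Array type (name varies by version)
```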
It does for the user who is familiar with Python and Numpy. With some effort your Python and Numpy code becomes orders of magnitude faster. Saying that those two things have nothing to do with each other is missing the point.
1. this is a thread about cpython. jax is as relevant to users of cpython as CUDA or OpenCL or whatever. jax cannot do absolutely anything with e.g. django.
2. for all intents and purposes all numerical code always runs in a lower-level implementation (C++, CUDA, XLA, whatever). so from that perspective, jax is just a convenient way to get from numerical python (i.e., loops and muls and adds) to kernels.
I didn't claim Jax can accelerate Django; it all depends on the code. A lot of our Python code is/was running partly in cpython and partly in extension modules such as Numpy.
There are many ways to achieve faster Python execution. One is a faster cpython implementation, another is moving CPU-intensive parts of the code to extension modules (such as Numpy). Yet another is to jit-compile Python (and Numpy) code to run on accelerators.
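a minimal sketch of that third route (illustrative only; the speedup obviously depends on the code and the backend):

```python
import jax
import jax.numpy as jnp

def step(x):
    # numpy-style numerical code, written against the jax.numpy API
    return jnp.tanh(x @ x.T).sum()

step_jit = jax.jit(step)

x = jnp.ones((512, 512))
step_jit(x)  # first call: trace + XLA compile for the current backend
step_jit(x)  # later calls reuse the compiled executable
```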
given who you are (googling your name) i'm surprised that you would say this. jax does not jit compile python in any sense of the word `Python`. jax is a tracing mechanism for a very particular set of "programs" specified using python; i put programs in quotes because it's not like you could even use it to trace through `if __name__ == "__main__"` since it doesn't know (and doesn't care) anything about python namespaces. it's right there in the first sentence of the description:
>JAX is Autograd and XLA
autograd for tracing and building the tape (wengert list) and xla for the backend (i.e., actual kernels). there is no sense in which jax will ever play a role in something like faster hash tables or more efficient loads/stores or virtual function calls.
in fact it doesn't even jit in the conventional understanding of jit, since there is no machine code that gets generated anew based on code paths (it simply picks different kernels and such that have already been compiled). not that i fault you for this substitution since everyone in ML does this (pytorch claims to jit as well).
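to make the tracing point concrete, `jax.make_jaxpr` shows exactly what jax records of a function - the python control flow is already gone by the time jax sees it:

```python
import jax
import jax.numpy as jnp

def f(x):
    for _ in range(3):  # plain python loop, unrolled at trace time
        x = jnp.sin(x) * 2.0
    return x

# the jaxpr is the recorded tape (wengert list); the python loop
# doesn't survive into it - just three sin/mul pairs remain
print(jax.make_jaxpr(f)(1.0))
```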
I agree with you that making CPython faster, or rewriting CPython entirely as with Cinder, are more general-purpose ways to make Python faster, while Jax is much more specific and limited and requires transforming your Python code, often manually.
You miss my point that all of those efforts are making slow Python code run faster. So claiming that 'these two things have nothing to do with each other' is wrong, because they share the goal of 'making Python code run faster'.
Some of that involves making cpython faster, some of it means moving execution into C (numpy is mentioned in that PDF), and some involves jit-compiling and moving execution onto a GPU or TPU (for example using XLA). The common part is 'making Python code run faster'. Some of that is automatic, some requires manual effort.
Jax can jit some Python functions, but it cannot efficiently jit everything. That is what I meant by decoration and 'some effort' - for example, replacing `if` conditions with np.where, etc. See also https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html
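a minimal sketch of that rewrite (toy function; the failure mode is the one the linked page documents):

```python
import jax
import jax.numpy as jnp

@jax.jit
def relu_branchy(x):
    if x > 0:  # python `if` needs a concrete bool; fails on traced values
        return x
    return 0.0

@jax.jit
def relu_traceable(x):
    return jnp.where(x > 0, x, 0.0)  # the branch stays inside the trace

relu_traceable(2.0)   # fine
# relu_branchy(2.0)   # raises a tracer/concretization error under jit
```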
My background is in physics simulation, and I advise the Brax team - basically a physics engine written in Python, accelerated to run on accelerators, see https://github.com/google/brax
The entire physics step, including collision detection and physics solver, is jit compiled.
there is not a single org anywhere in the world that uses pure python to do numerics. kids do that during their first linear algebra or ml class. that's it.
>For example replacing IF conditions by np.where etc
i've already addressed this - this is not jit compilation.
Many orgs use Python+Numpy, and that can be made faster using Jax.
>> this is not jit compilation.
I disagree. Jax jit uses XLA, and XLA is a JIT compiler. An XLA graph is created during the runtime of the host program, and JIT-compiled to native code for the CPU, GPU or TPU.
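this is easy to see from the jax side - a sketch using `jax.jit(...).lower(...)` (available in current jax; exact method names have shifted across versions):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

lowered = jax.jit(f).lower(jnp.ones((4,)))
print(lowered.as_text())      # the XLA/StableHLO program built at runtime
compiled = lowered.compile()  # XLA compiles it to native code for the backend
print(compiled(jnp.ones((4,))))
```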
correct and non-controversial
> An enormous number of people and products [use CoreML on Apple platforms]
non sequitur
EDIT: i see people are not aware of
https://en.wikipedia.org/wiki/Simpson%27s_paradox