Interesting: Given an input x to a layer f(x)=Wσ(x), where σ is an activation function and W is a weight matrix, the authors define the layer's "neural feature matrix" (NFM) as WᵀW, and show that throughout training, it remains proportional to the average outer product of the layer's gradients, i.e., WᵀW ∝ mean(∇f(x)∇f(x)ᵀ), with the mean computed over all samples in the training data. The authors posit that layers learn to up-weight features strongly related to model output via the NFM.
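For concreteness, here is a toy sketch (my own, not the paper's code) of how the two objects would be computed and compared for a one-hidden-layer ReLU net. The network, data, and the cosine-similarity check are all assumptions for illustration; the paper's claim is that the proportionality emerges over training, not at initialization:

```python
import numpy as np

# Toy illustration (mine, not the paper's code): how the NFM and the average
# gradient outer product (AGOP) over a layer's inputs are each computed.
rng = np.random.default_rng(0)
d, k, n = 10, 32, 500                        # input dim, layer width, number of samples
W = rng.normal(size=(k, d)) / np.sqrt(d)     # first-layer weights (untrained here)
a = rng.normal(size=k) / np.sqrt(k)          # readout weights

# network: f(x) = a . relu(W x); grad_model returns the gradient of the scalar
# output with respect to the input of the layer (chain rule through the ReLU)
def grad_model(x):
    return W.T @ (a * (W @ x > 0))

X = rng.normal(size=(n, d))

# Neural feature matrix of the layer
nfm = W.T @ W

# Average gradient outer product over the sample
agop = sum(np.outer(grad_model(x), grad_model(x)) for x in X) / n

# The ansatz says these become (approximately) proportional over training;
# here we just measure alignment via the entrywise cosine similarity.
cos = np.sum(nfm * agop) / (np.linalg.norm(nfm) * np.linalg.norm(agop))
print("entrywise cosine similarity:", round(cos, 3))
```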
The authors do interesting things with the NFM, including explaining why pruning should even be possible and why we see grokking during learning. They also train a kernel machine iteratively, at each step alternating between (1) fitting the model's kernel matrix to the data and (2) computing the average gradient outer product of the model and replacing the kernel matrix with it. The motivation is to induce the kernel machine to "learn to identify features." The approach seems to work well. The authors' kernel machine outperforms all previous approaches on a common tabular data benchmark.
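Roughly, the alternating procedure looks like the sketch below. This is my own minimal version (Gaussian kernel, plain kernel ridge regression, a fixed number of rounds, and an ad hoc rescaling of the feature matrix), not the authors' exact recipe or hyperparameters:

```python
import numpy as np

# Minimal sketch of the alternating scheme described above (my simplifications;
# not the authors' exact algorithm or hyperparameters).
rng = np.random.default_rng(1)
n, d = 300, 5
X = rng.normal(size=(n, d))
y = np.sin(X[:, 0]) + 0.1 * rng.normal(size=n)   # target depends only on coordinate 0

def kernel(A, B, M, h2):
    # Mahalanobis-style Gaussian kernel k(a, b) = exp(-(a-b)' M (a-b) / (2 h2))
    diff = A[:, None, :] - B[None, :, :]
    sq = np.einsum('ijk,kl,ijl->ij', diff, M, diff)
    return np.exp(-sq / (2 * h2))

M = np.eye(d)
for _ in range(5):
    # (1) fit kernel ridge regression using the current feature matrix M
    K = kernel(X, X, M, h2=d)
    alpha = np.linalg.solve(K + 1e-3 * np.eye(n), y)

    # (2) average gradient outer product of the fitted predictor over the data;
    #     this replaces the feature matrix for the next round
    agop = np.zeros((d, d))
    for j in range(n):
        diff = X[j] - X                          # x_j - x_i for all i, shape (n, d)
        grad = -M @ (diff * (alpha * K[j])[:, None]).sum(axis=0) / d
        agop += np.outer(grad, grad)
    M = agop * d / np.trace(agop)                # rescale (my choice) to keep the bandwidth sensible

print(np.round(np.diag(M), 2))                   # coordinate 0 should be strongly up-weighted
```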
This NFM is a curious quantity. It has the flavor of a metric on the space of inputs to that layer. However, the fact that W’W remains proportional to DfDf’ seems to be an obvious consequence of the very form of f… since DfDf’ is itself Ds’W’WDs, this should be expected under some assumptions (perhaps mild) on the statistics of Ds, no?
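Concretely, here is a quick numeric check of what I mean (my own toy, using tanh as the activation and writing the gradient outer product as the Jacobian Gram matrix Df’Df):

```python
import numpy as np

# For a single layer f(x) = W s(x) with elementwise s, the Jacobian is
# W diag(s'(x)), so Df'Df = Ds' W'W Ds at every point x.
rng = np.random.default_rng(3)
W = rng.normal(size=(4, 6))
x = rng.normal(size=6)

s = np.tanh
ds = lambda z: 1.0 - np.tanh(z) ** 2

# Jacobian of f at x, by central finite differences
Jf = np.zeros((4, 6))
eps = 1e-6
for j in range(6):
    e = np.zeros(6)
    e[j] = eps
    Jf[:, j] = (W @ s(x + e) - W @ s(x - e)) / (2 * eps)

Ds = np.diag(ds(x))
print(np.allclose(Jf, W @ Ds, atol=1e-6))                      # Jacobian is W diag(s'(x))
print(np.allclose(Jf.T @ Jf, Ds @ W.T @ W @ Ds, atol=1e-5))    # Df'Df = Ds' W'W Ds
```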
It has more than just the flavor of a metric; it’s essentially a metric, because any Gram matrix M’M is positive semi-definite (since x’M’Mx = y’y = <y,y> ≥ 0, where y = Mx), and strictly positive definite whenever M has full column rank. It could be interpreted as a metric which only cares about distances along the dimensions that the NN cares about.
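Concretely, the induced distance is just Euclidean distance after mapping through W, so directions the layer ignores contribute nothing. A toy illustration with a made-up W:

```python
import numpy as np

# The (pseudo-)metric induced by W'W is Euclidean distance after mapping
# through W: directions that W annihilates contribute nothing to the distance.
rng = np.random.default_rng(2)
W = rng.normal(size=(3, 5))
W[:, -1] = 0.0                      # pretend the layer ignores the last input dimension

def dist_nfm(x, y, W):
    diff = x - y
    return float(np.sqrt(diff @ (W.T @ W) @ diff))   # equals ||W (x - y)||

x = rng.normal(size=5)
y = x.copy()
y[-1] += 10.0                       # a large change along the ignored dimension

print(dist_nfm(x, y, W))            # 0.0: the induced metric doesn't "see" that direction
print(np.linalg.norm(x - y))        # 10.0: plain Euclidean distance is large
```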
I agree that the proportionality seems a little obvious; I think what’s most interesting is what they do with the quantity, though I’m surprised no one else has tried this.
Yes, of course any positive definite matrix can be used as a metric on the corresponding Euclidean space - but that doesn’t mean it’s necessarily useful as a metric. Hence I think it’s useful to distinguish between things from which a metric could be constructed, and things which, when applied as a metric, actually provide some benefit.
In particular, if we believe the manifold hypothesis, then one should expect a useful metric on features to be local and not static - the quantity W’W clearly does not depend on the inputs to the layer at inference time, and so is static.
How are you defining the utility of a metric here? It’s not clear to me why a locally varying metric would necessarily be more 'useful' than a global one in the context of the manifold hypothesis.
Moreover, if I’m understanding their argument right, then W’W is proportional to an average of the exterior derivative of the manifold representing the prediction surface of any given NN layer (averaging with respect to the measure defined by the data-generating process). While this averaging by definition leaves some of the local information on the cutting-room floor, the result is going to be far more interpretable (because we've discarded all that distracting local detail), and I would assume it will still retain the large-scale structure of the underlying manifold (outside of some gross edge cases).
If one thinks of a metric as a “distance measure”, which is to say, how similar is some input x to the “feature” encoded by a layer f(x), and if this feature corresponds to some submanifold of the data, then naturally this manifold will have curvature and the distance measure will do better to account for this curvature. Then generally the metric (in this case, defining a connection on the data manifold) should encode this curvature and therefore is a local quantity. If one chooses a fixed metric, then implicitly the data manifold is being treated as a flat space - like a vector space - which generally it is not. My favorite example for this is the earth, a 2-sphere that is embedded in a higher dimensional space. The correct similarity measure between points is the length of a geodesic connecting those points. If instead one were to just take a flat map (choice of coordinates) and compare their Euclidean distance, it would only be a decent approximation of similarity if the points are already very close. This is like the flat earth fallacy.
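A tiny numeric version of the flat-earth point (my own illustration): for nearby points on the unit sphere the Euclidean (chord) distance approximates the geodesic well, but for far-apart points it badly understates it:

```python
import numpy as np

# Geodesic (great-circle) vs. Euclidean (chord) distance on the unit sphere:
# nearby points agree, but for far-apart points the flat measure badly
# understates the true distance along the manifold.
def geodesic(p, q):
    return np.arccos(np.clip(np.dot(p, q), -1.0, 1.0))   # arc length on the unit sphere

def chord(p, q):
    return np.linalg.norm(p - q)

p    = np.array([1.0, 0.0, 0.0])
near = np.array([np.cos(0.1), np.sin(0.1), 0.0])
far  = np.array([-np.cos(0.1), np.sin(0.1), 0.0])         # nearly antipodal

print(geodesic(p, near), chord(p, near))   # ~0.100 vs ~0.100 (flat is a fine approximation)
print(geodesic(p, far),  chord(p, far))    # ~3.042 vs ~1.998 (flat badly understates)
```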
But this argument seems analogous to someone saying that average height is less “correct” than the full original data set because every individual’s height is different. In one sense it’s not wrong, but it kind of misses the point of averaging. The full local metric tensor defined on the manifold is going to have the same “complexity” as the manifold itself; it’s a bad way of summarizing a model because it’s not any simpler than the model. Their approach is to average that metric tensor over the region of the manifold swept out by the training data, and they show that this average empirically reflects something meaningful about the underlying response manifold in problems we’re interested in. Whether or not this average quantity can entirely reproduce the original manifold is kind of irrelevant (and indeed undesirable); the point is that it (a) represents something meaningful about the model and (b) is low-dimensional enough for a human to reason about. Although globally it will not be accurate to distances along the surface, presumably it is “good enough” to at least first order for much of the support of the training data.
Yes, that sounds right, but it doesn't make the work less worthwhile or less interesting.
The authors do interesting things with the NFM, including explaining why pruning should even be possible and why we see grokking during learning. They also train a kernel machine iteratively, at each step alternating between (1) fitting the model's kernel matrix to the data and (2) computing the average gradient outer product of the model and replacing the kernel matrix with it. The motivation is to induce the kernel machine to "learn to identify features." The approach seems to work well, outperforming all previous approaches on tabular data.
PS. I've updated my comment to add these additional points.