
Sure, so to see how these things can be learned, we should be a little more precise about how they work.

Each token is a vector, and from that vector we compute three things - a query, a key and a value. Each of these is typically computed by multiplying the token's vector by its own matrix (aka a linear projection), so there are three separate matrices. It's the values in these matrices that we need to learn.
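A minimal sketch of that projection step in plain Python, assuming 3-dimensional token vectors and small hand-picked matrices. `W_q`, `W_k`, `W_v` and their numbers are hypothetical; in a real model they are learned and much larger. Each matrix is stored as a list of rows and multiplies the token vector from the left; many libraries use the row-vector convention `x @ W` instead, which is the same idea transposed.

```python
def matvec(matrix, vector):
    """Multiply a matrix (a list of rows) by a column vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]

# Hypothetical projection matrices; training would adjust these numbers.
W_q = [[1, 0, 0], [0, 1, 0], [0, 0, -1]]
W_k = [[0, 1, 0], [1, 0, 0], [0, 0, 1]]
W_v = [[1, 1, 0], [0, 1, 1], [1, 0, 1]]

def project(token_vector):
    """Compute a token's query, key and value with three separate linear projections."""
    return matvec(W_q, token_vector), matvec(W_k, token_vector), matvec(W_v, token_vector)

query, key, value = project([1, 2, 3])
print(query, key, value)  # [1, 2, -3] [2, 1, 3] [3, 5, 4]
```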

When performing an attention step, for a given token we compare its "query" to every token's "key" (including its own key - a token can attend to itself), typically by taking a dot product. This gives us a score for how important we think that key is. We normalize those scores to sum to one (typically via a softmax operation). Essentially, we have one "unit" of attention, and we're going to spread it across all the tokens: some we will pay a lot of attention to, and others very little.
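The scoring and normalization for a single token can be sketched like this, assuming plain Python lists; the function names are illustrative. Standard transformer attention also divides each score by the square root of the key dimension before the softmax, which this sketch leaves out to match the description above.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    """Turn raw scores into positive weights that sum to one."""
    top = max(scores)  # subtracting the max avoids overflow and doesn't change the result
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(query, keys):
    """Score one token's query against every key (its own included), then normalize."""
    return softmax([dot(query, k) for k in keys])

weights = attention_weights([1, 0], [[2, 0], [0, 2], [1, 1]])
print([round(w, 3) for w in weights])  # [0.665, 0.09, 0.245]
```

The key most aligned with the query gets the largest share, but every key still receives some weight.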

But what does it mean to pay a lot or only a little attention to other tokens? At the end of this whole procedure, we're going to arrive at a new vector that represents our new understanding/meaning for the token we're working on. This vector will be computed as a weighted sum of the values from all the tokens we're attending to. The weights are our attention scores (the normalized query-key similarity scores).
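Putting the pieces together, one attention step that produces a new vector for every token might look like the following. This is a sketch assuming the queries, keys and values have already been computed; `attention_step` is a hypothetical name, and real implementations do this with batched matrix multiplications rather than Python loops.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def weighted_sum(weights, vectors):
    """Sum of the vectors, each scaled by its weight."""
    return [sum(w * v[d] for w, v in zip(weights, vectors)) for d in range(len(vectors[0]))]

def attention_step(queries, keys, values):
    """For each token: weights from its query vs. all keys, then a weighted sum of all values."""
    outputs = []
    for q in queries:
        weights = softmax([dot(q, k) for k in keys])
        outputs.append(weighted_sum(weights, values))
    return outputs

# Zero queries give equal scores, so each token's output is just the average value.
out = attention_step([[0, 0], [0, 0]], [[1, 2], [3, 4]], [[2.0, 0.0], [0.0, 4.0]])
print(out)  # [[1.0, 2.0], [1.0, 2.0]]
```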

So as a simple example, suppose I have three tokens, A, B and C, and let's focus on the attention operation for A. Say A's query vector is [1 2 -1]. A's key vector is [3 -1 0], B's key vector is [3 0 -1] and C's key vector is [0 1 -3]. Taking dot products gives raw attention scores of 1 for A (attending to itself), 4 for B, and 5 for C. Rather than take a messy softmax, let's just divide each score by their total (10) to get weights of 0.1, 0.4 and 0.5 for simplicity.
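A quick check of this example in Python, taking B's key as [3 0 -1] so that its score comes out to 4. Dividing by the total only works here because all three scores are positive; the real softmax, computed alongside for comparison, handles any scores and exaggerates the gaps between them.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

query_a = [1, 2, -1]
keys = {"A": [3, -1, 0], "B": [3, 0, -1], "C": [0, 1, -3]}

raw = {name: dot(query_a, key) for name, key in keys.items()}
print(raw)  # {'A': 1, 'B': 4, 'C': 5}

# The simplified normalization: divide each score by the total.
total = sum(raw.values())
simple = {name: s / total for name, s in raw.items()}
print(simple)  # {'A': 0.1, 'B': 0.4, 'C': 0.5}

# The usual softmax, for comparison: exponentiate, then normalize.
exp_total = sum(math.exp(s) for s in raw.values())
soft = {name: math.exp(s) / exp_total for name, s in raw.items()}
print({name: round(w, 3) for name, w in soft.items()})  # {'A': 0.013, 'B': 0.265, 'C': 0.721}
```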

Now that we have our attention weights, we also need to know each token's value. Let's say they are [1 0 1] for A, [-1 2 0] for B, and [1 1 1] for C. So our final output for this attention step will be 0.1 * [1 0 1] + 0.4 * [-1 2 0] + 0.5 * [1 1 1]. This gives us [0.2 1.3 0.6], which will be our new representation of A for the next step (in practice, there are some additional network layers that do more processing).
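The same weighted sum in Python, using the weights and values from this example. The rounding in the print only hides floating-point noise (the first entry actually comes out a hair below 0.2).

```python
weights = [0.1, 0.4, 0.5]                    # attention from A to A, B, C
values = [[1, 0, 1], [-1, 2, 0], [1, 1, 1]]  # value vectors of A, B, C

new_a = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(3)]
print([round(x, 6) for x in new_a])  # [0.2, 1.3, 0.6]
```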

Okay, so how can we learn any of the matrices that go from a token vector to a query, a key and a value? The important thing is that all of this is built from smooth operations (additions and multiplications, plus the exponentials and division inside the softmax), so it's all nicely differentiable. And because the attention is "soft" (meaning we always attend at least a little bit to everything, as opposed to "hard" attention where we ignore some items entirely), we can even compute gradients through the attention scores.
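One way to see why soft attention is trainable and hard attention is not: nudge one raw score slightly and watch the resulting weight. The sketch below (function names are illustrative) compares the softmax against a hypothetical one-hot "hard" version that puts all attention on the top-scoring token. It estimates the slope with a central finite difference and checks the softmax case against the known derivative dw_i/ds_i = w_i(1 - w_i).

```python
import math

def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def hard_attention(scores):
    """All attention on the single highest-scoring token, none on the rest."""
    best = max(range(len(scores)), key=lambda i: scores[i])
    return [1.0 if i == best else 0.0 for i in range(len(scores))]

def weight_slope(attend, scores, i, eps=1e-6):
    """Central-difference estimate of d attend(scores)[i] / d scores[i]."""
    up, down = list(scores), list(scores)
    up[i] += eps
    down[i] -= eps
    return (attend(up)[i] - attend(down)[i]) / (2 * eps)

scores = [1.0, 4.0, 5.0]  # raw scores of A, B, C from the example
w = softmax(scores)
soft_slope = weight_slope(softmax, scores, 0)
print(soft_slope, w[0] * (1 - w[0]))            # agree closely: small but nonzero
print(weight_slope(hard_attention, scores, 0))  # 0.0: A's weight never moves
```

Even though A gets only about 1% of the attention under softmax, its weight still responds to its score, which is the "small sliver" a gradient can work with; the hard version gives no signal at all.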

Put a simpler way, I can ask "if I had included a bit more of A's value and a bit less of B's value, would my final output have been closer to the target?". To include a bit less of B, I need to make A's query and B's key a little further apart (a lower dot product). And to make them a little further apart, I need to adjust the numbers in the matrices that produce them. Similarly, I can ask "if C's value had been a little larger in the first slot and a little smaller in the third, would my final output have been closer to the target?", and adjust the value matrix in the same way. Even if the attention to another token is very low, there's at least a small sliver of contribution from it, so we can still learn that we would've been better off having more (or even less) of it.
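The "would my output have been closer?" question is exactly what a gradient answers. Below is a minimal sketch under a toy setup: three hypothetical one-hot token vectors, identity matrices as the starting W_q, W_k and W_v, and a made-up target for A's output. It uses finite differences instead of backpropagation purely to stay short (backprop computes the same gradients far more efficiently), then takes one gradient-descent step on all three matrices at once.

```python
import math

def matvec(m, x):
    return [sum(a * b for a, b in zip(row, x)) for row in m]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(params, tokens, i):
    """Attention output for token i, given params = [W_q, W_k, W_v]."""
    w_q, w_k, w_v = params
    q = matvec(w_q, tokens[i])
    keys = [matvec(w_k, x) for x in tokens]
    values = [matvec(w_v, x) for x in tokens]
    weights = softmax([dot(q, k) for k in keys])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]

def loss(params, tokens, i, target):
    """Squared distance between token i's output and the target."""
    return sum((o - t) ** 2 for o, t in zip(attend(params, tokens, i), target))

def numeric_grads(params, tokens, i, target, eps=1e-6):
    """Finite-difference gradient of the loss w.r.t. every matrix entry."""
    grads = []
    for m in params:
        g = [[0.0] * len(row) for row in m]
        for r in range(len(m)):
            for c in range(len(m[r])):
                old = m[r][c]
                m[r][c] = old + eps
                up = loss(params, tokens, i, target)
                m[r][c] = old - eps
                down = loss(params, tokens, i, target)
                m[r][c] = old  # restore the entry
                g[r][c] = (up - down) / (2 * eps)
        grads.append(g)
    return grads

def identity(n):
    return [[1.0 if r == c else 0.0 for c in range(n)] for r in range(n)]

tokens = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # hypothetical A, B, C
params = [identity(3), identity(3), identity(3)]             # starting W_q, W_k, W_v
target = [0.0, 0.0, 1.0]                                     # hypothetical target for A

before = loss(params, tokens, 0, target)
grads = numeric_grads(params, tokens, 0, target)
learning_rate = 0.05
for m, g in zip(params, grads):
    for r in range(len(m)):
        for c in range(len(m[r])):
            m[r][c] -= learning_rate * g[r][c]  # step against the gradient
after = loss(params, tokens, 0, target)
print(round(before, 4), after < before)  # 0.9979 True
```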

Learning the initial embeddings (the vectors that represent each word in the vocabulary, before any processing) is done in the same way - you trace all the way back through the network and ask "if the embedding for the word 'bank' had been a little more like this, would my answer have been closer?", and adjust accordingly.
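The same trick reaches the embedding table. In this sketch the embeddings are a hypothetical dictionary of 2-d vectors, and the Q/K/V projections are dropped (treated as identity) to keep it short. Only words that actually appear in the input get a nonzero gradient, so a step changes "bank" and "river" but leaves "money" untouched.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def first_token_output(vectors):
    """Attention output for the first token, using the vectors directly as Q, K and V."""
    weights = softmax([dot(vectors[0], k) for k in vectors])
    return [sum(w * v[d] for w, v in zip(weights, vectors)) for d in range(len(vectors[0]))]

def loss(embeddings, sentence, target):
    out = first_token_output([embeddings[word] for word in sentence])
    return sum((o - t) ** 2 for o, t in zip(out, target))

embeddings = {"bank": [0.5, -0.2], "river": [0.1, 0.9], "money": [0.8, 0.3]}  # hypothetical
sentence, target = ["bank", "river"], [0.2, 0.6]                            # hypothetical

eps, learning_rate = 1e-6, 0.1
before = loss(embeddings, sentence, target)
grads = {}
for word, vec in embeddings.items():
    grads[word] = []
    for d in range(len(vec)):
        old = vec[d]
        vec[d] = old + eps
        up = loss(embeddings, sentence, target)
        vec[d] = old - eps
        down = loss(embeddings, sentence, target)
        vec[d] = old  # restore the entry
        grads[word].append((up - down) / (2 * eps))

for word, vec in embeddings.items():
    for d in range(len(vec)):
        vec[d] -= learning_rate * grads[word][d]  # nudge each embedding downhill

after = loss(embeddings, sentence, target)
print(grads["money"], after < before)  # [0.0, 0.0] True
```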

Understanding what exactly the queries, keys and values represent is often very difficult. Sometimes we can look at which words attend to which other words and make up a convincing story: "Oh, in this layer, the verb is attending to the corresponding subject" or whatever. But in practice, the real meaning of a particular internal representation is going to be very fuzzy and not have a single clear concept behind it.

There is no explicit guidance to the network like "you should attend to this word because it's relevant to this other word." The only guidance is what the correct final output is (usually for LLMs the training task is to predict the next word, but it could be something else). And then the training algorithm adjusts all the parameters, including the embeddings, the QKV matrices, and all the other weights of the network, in whatever direction would make that correct output more likely.
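For the usual next-word objective, "make the correct output more likely" is typically measured with cross-entropy: the negative log of the probability the model assigns to the true next token. A small sketch with made-up scores (logits) over a hypothetical four-word vocabulary. Its gradient with respect to the logits is the softmax probabilities minus 1 at the correct token, and that signal is what flows back through every layer.

```python
import math

def next_token_loss(logits, target_index):
    """Cross-entropy: -log softmax(logits)[target_index], computed stably."""
    top = max(logits)
    log_normalizer = top + math.log(sum(math.exp(x - top) for x in logits))
    return log_normalizer - logits[target_index]

vocab = ["river", "money", "loan", "fish"]  # hypothetical vocabulary
logits = [2.0, 0.5, 0.1, -1.0]              # hypothetical model scores for the next word

for word in vocab:
    print(word, round(next_token_loss(logits, vocab.index(word)), 3))
# the highest-scoring word ("river") has the lowest loss
```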



This was an excellent explanation, thank you for taking the time to write it out!



