Attention from Scratch in Julia
- Overview
- Embeddings
- Attention
- Minimal Architecture
- Forward pass
- Backwards pass
- Sentiment data
- Training
- Results
- Full Code
- References
Overview
This post explores the attention mechanism by building it from scratch and applying it to a rudimentary sentiment analysis example.
Embeddings
We start with a sentence. A sentence with $n$ words is essentially a sequence of words $w_1,w_2,\ldots,w_n$. To process them numerically, we need to map each word into a vector representation, i.e. $word\ w_i \to vector\ x_i$. To do so, we can use pretrained word embeddings such as GloVe. This is represented by
\(\begin{equation}E(w_i) = x_i \tag{1}\label{eq:embeddings}\end{equation}\) where $E$ is an embeddings transformation.
First import an embeddings package and load in GloVe embeddings.
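One way to do this is with the Embeddings.jl package; the choice of package and of the default GloVe file (assumed here to be the 50-dimensional glove.6B vectors) is ours, not prescribed by the post.

```julia
using Embeddings

# Load pretrained GloVe vectors; the first call downloads the data via DataDeps.
const embtable = load_embeddings(GloVe{:en})

# embtable.embeddings is a (dimension × vocabulary size) matrix,
# embtable.vocab is the corresponding vector of words.
```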
Next we write two helper functions: get_embeddings, which looks up a word in embtable and returns its embedding (\ref{eq:embeddings}), and word_tokeniser, which splits a sentence into a vector of words.
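A minimal sketch of these two helpers, assuming the embtable loaded above (the dictionary-based lookup is an implementation choice on our part, not necessarily the post's original code):

```julia
# Map each vocabulary word to its column index in the embedding matrix.
const word_index = Dict(word => i for (i, word) in enumerate(embtable.vocab))

# Equation (1): look up the embedding vector for a word.
get_embeddings(word) = embtable.embeddings[:, word_index[word]]

# Split a sentence into lowercase word tokens (simple whitespace tokenisation).
word_tokeniser(sentence) = split(lowercase(sentence))
```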
We can try them out in the Julia REPL.
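For example, with the helper sketches above (the reported vector length assumes the 50-dimensional embeddings):

```julia
julia> word_tokeniser("Attention is all you need")
5-element Vector{SubString{String}}:
 "attention"
 "is"
 "all"
 "you"
 "need"

julia> size(get_embeddings("attention"))
(50,)
```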
Attention
We have three learnable weight parameters: queries $\boldsymbol{Q}$, keys $\boldsymbol{K}$ and values $\boldsymbol{V}$, with dimensions $(d_q+1)\times(d_q+1)$, $(d_k+1)\times(d_k+1)$ and $(d_v+1)\times(d_v+1)$ respectively. Let $\boldsymbol{x}$ be a $(d_q+1)\times n$ dimensional matrix of word embeddings. Next, apply a linear transformation to the embeddings $\boldsymbol{x}$ with all three weight matrices, giving us $\boldsymbol{q},\boldsymbol{k},\boldsymbol{v}$. \(\begin{equation} \boldsymbol{Q}\boldsymbol{x} = \boldsymbol{q}\\ \boldsymbol{K}\boldsymbol{x} = \boldsymbol{k}\\ \boldsymbol{V}\boldsymbol{x} = \boldsymbol{v}\\ \tag{2}\label{eq:qkv} \end{equation}\)
For simplicity's sake, let us assume that every word embedding of a sentence is a key, that keys = values, and that each word embedding is also a query. This implies that $d_q=d_k$.
\[Cosine\ Similarity\ \boldsymbol{e} = \frac{\boldsymbol{q}^T\boldsymbol{k}}{\sqrt{d_k}}\tag{3}\label{eq:cosine similarity}\]Note that the cosine similarity has been scaled by $\frac{1}{\sqrt{d_k}}$. To gain an intuition as to why, compare the norms of a 2-element vector and a 3-element vector, e.g. \(\begin{aligned} &\sqrt{\begin{bmatrix}2 \\ 2\end{bmatrix} \bullet \begin{bmatrix}2 \\ 2\end{bmatrix}}=\sqrt{2^2+2^2} &=2\sqrt{2} \end{aligned}\) \(\begin{aligned} &\sqrt{\begin{bmatrix}2 \\ 2 \\ 2\end{bmatrix} \bullet \begin{bmatrix}2 \\2 \\ 2\end{bmatrix}}=\sqrt{2^2+2^2+2^2} &=2\sqrt{3} \end{aligned}\) It is clear that as the dimension $d$ grows, the norm scales by $\sqrt{d}$.
\[Attention\ Weights \ \boldsymbol{\alpha} = softmax(\frac{\boldsymbol{q}^T\boldsymbol{k}}{\sqrt{d_k}})\tag{4}\label{eq:attention weights}\] \[Attention\ \boldsymbol{z}= softmax(\frac{\boldsymbol{q}^T\boldsymbol{k}}{\sqrt{d_k}})\boldsymbol{v} \tag{5}\label{eq:attention}\]Let us define two helper functions, LinearTransform and softmax, as below.
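A sketch of the two helpers; the dims keyword controlling which dimension the softmax is taken along is an implementation choice of ours:

```julia
# Affine transformation W * x + b, applied column by column via broadcasting.
LinearTransform(x, W, b) = W * x .+ b

# Numerically stable softmax along the given dimension
# (the maximum is subtracted before exponentiating to avoid overflow).
function softmax(x; dims=1)
    ex = exp.(x .- maximum(x; dims=dims))
    return ex ./ sum(ex; dims=dims)
end
```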
Then define a function called Attention as below.
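The following is a sketch consistent with the line-by-line walkthrough below; the orientation of the final matrix product is our assumption, chosen so that the context vectors come out as a $d_v \times n$ matrix:

```julia
function Attention(x, Q, Qb, K, Kb, V, Vb)
    q = LinearTransform(x, Q, Qb)       # equation (2)
    k = LinearTransform(x, K, Kb)
    v = LinearTransform(x, V, Vb)
    α = AttentionWeights(q, k)          # equation (4)
    return q, k, v, α, v * α'           # equation (5): context vectors z
end
```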
The first line creates a function called Attention, which takes in the word embeddings $\boldsymbol{x}$ and the trainable parameters $\boldsymbol{Q,Qb,K,Kb,V,Vb}$ as arguments. Note that I have chosen to separate the Query $\boldsymbol{Q}$, Key $\boldsymbol{K}$, Value $\boldsymbol{V}$ weight matrices from their biases $\boldsymbol{Qb,Kb,Vb}$ for clarity. So $\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V},\boldsymbol{x}$ have dimensions $(d_q+1)\times(d_q+1)$, $(d_q+1)\times(d_q+1)$, $(d_q+1)\times(d_q+1)$ and $(d_q+1)\times n$ respectively.
The next three lines apply a linear transformation to the word embeddings as per equation (\ref{eq:qkv}).
The next line calls the AttentionWeights function, which returns the attention weight matrix $\alpha$ as per equation (\ref{eq:attention weights}).
The last line implements equation (\ref{eq:attention}) and returns the context vectors.
The AttentionWeights function first computes the unnormalized cosine similarity matrix $e$ between queries and keys as in (\ref{eq:cosine similarity}), scaled by the square root of the key dimension. Then, we apply the softmax function along each row of the resulting matrix.
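A sketch matching that description (the dims=2 keyword comes from the softmax helper defined earlier):

```julia
function AttentionWeights(q, k)
    d_k = size(k, 1)
    e = (q' * k) ./ sqrt(d_k)     # unnormalized scaled scores, equation (3)
    return softmax(e; dims=2)     # softmax along each row, equation (4)
end
```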
Now we have our attention building block! For a more detailed explanation of the forward and backpropagation steps for attention, please refer to this post.
Minimal Architecture
To see attention in action, let us consider a minimal functional architecture with attention at the heart of it.
The architecture consists of four intermediate layers: an attention layer, a feedforward layer, a pooling layer and a softmax layer. For the input, we pass in a concatenation of the word embeddings and the positional embeddings. At the final output layer, we obtain the classification probabilities. For this example, let us assume that we are doing sentiment analysis on a sentence. The sentiment labels are one-hot vectors: positive $[0,0,1]$, negative $[1,0,0]$, neutral $[0,1,0]$. So when we pass an input sentence through the above architecture, we get the probabilities of the sentence being positive, negative or neutral.
Forward pass
Let us now create a function Forwardprop for the forward pass through the above architecture. The function takes in a word embedding matrix $\boldsymbol{x}$, the Attention parameters $\boldsymbol{Q,Q_b,K,K_b,V,V_b}$ and the Feedforward parameters $\boldsymbol{W,b}$ as arguments.
The if condition handles a sentence with just a single word by reshaping the 1D word embedding to 2D. The vcat concatenates the word embeddings $x$ with their position embeddings (the index position of each word in the sentence).
We next pass in the arguments to the Attention function as defined in Attention, which returns the queries, keys, values, attention weights and context vectors.
In the third layer, we pass the context vectors $z$, currently a $d_v \times n$ matrix, through a feedforward layer. The Feedforward function is shown below.
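A sketch of Feedforward, assuming $\boldsymbol{W}$ is a $3 \times d_v$ matrix and $\boldsymbol{b}$ a length-3 bias vector:

```julia
# Affine layer mapping the d_v × n context vectors to a 3 × n matrix of class scores.
Feedforward(z, W, b) = W * z .+ b
```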
The feedforward layer transforms the context vectors $z$ to a $3 \times n$ matrix.
Now, before passing it to the softmax function, we need to reduce the $3 \times n$ matrix to a $3 \times 1$ vector. A simple, naive way is to average over the $n$ columns. Hence, we pass $f$ through a pooling layer to obtain a $3 \times 1$ vector before passing it through a softmax layer. The final output gives the probabilities of the sentiments.
Putting it all together,
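here is a sketch of Forwardprop following the description above. The positional embedding is simply the word's index appended as an extra row, and returning the position-augmented $x$ (so it can later be handed to the backward pass) is our assumption:

```julia
function Forwardprop(x, Q, Qb, K, Kb, V, Vb, W, b)
    if ndims(x) == 1
        x = reshape(x, :, 1)                      # single-word sentence → 2D matrix
    end
    n = size(x, 2)
    x = vcat(x, reshape(collect(1.0:n), 1, n))    # append the position index as a row
    q, k, v, α, z = Attention(x, Q, Qb, K, Kb, V, Vb)
    f = Feedforward(z, W, b)                      # 3 × n class scores
    pooled = sum(f; dims=2) ./ n                  # average over columns → 3 × 1
    p = softmax(pooled; dims=1)                   # sentiment probabilities
    return x, q, k, v, α, z, p
end
```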
Backwards pass
This section dives into the details of computing gradients required for backpropagation.
Cross Entropy
As we have a softmax function at the last layer, a good loss function to use is the cross entropy loss, defined below. Note that mean squared error can work poorly with softmax, explained here.
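A minimal cross-entropy loss sketch, assuming $\boldsymbol{p}$ holds the predicted probabilities and $\boldsymbol{y}$ the one-hot label:

```julia
# Cross entropy L = -Σ y log(p); the eps term guards against log(0).
CrossEntropyLoss(p, y) = -sum(y .* log.(p .+ eps(eltype(p))))
```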
Working our way backwards, we first compute the derivative of the loss with respect to the softmax input, which is simply the prediction $\boldsymbol{p}$ minus the training label $\boldsymbol{y}$.
\[\begin{aligned} \frac{\partial{L}}{\partial{\boldsymbol{p}}}=\boldsymbol{p}-\boldsymbol{y} \end{aligned}\]
Pooling
Recall that for the pooling layer, we averaged across all columns as below.
\[\begin{aligned} \boldsymbol{f}&=\begin{bmatrix}\boldsymbol{f_1} \ \boldsymbol{f_2} \ \boldsymbol{f_3} \ \ldots \ \boldsymbol{f_n} \end{bmatrix}\\ \boldsymbol{p}&=\frac{1}{n}\sum_{i=1}^{n}\boldsymbol{f_i}\\ &=\frac{1}{n}(\boldsymbol{f_1} + \boldsymbol{f_2} + \boldsymbol{f_3} + \ldots + \boldsymbol{f_n}) \\ \end{aligned}\tag{6}\label{eq:pooling}\]The derivative of (\ref{eq:pooling}) gives us the gradients
\[\begin{aligned} \frac{\partial{\boldsymbol{p}}}{\partial{\boldsymbol{f}}} &=\frac{1}{n}\begin{bmatrix}1 \ 1 \ \ldots \ 1 \end{bmatrix} \end{aligned}\] \[\begin{equation} \frac{\partial{L}}{\partial{\boldsymbol{f}}}=\frac{\partial{L}}{\partial{\boldsymbol{p}}}\frac{\partial{\boldsymbol{p}}}{\partial{\boldsymbol{f}}}\tag{7}\label{eq:dLdf} \end{equation}\]
Feedforward
The local gradients of $\boldsymbol{f}$ are \(\begin{equation}\frac{\partial{\boldsymbol{f}}}{\partial{\boldsymbol{z}}}=\boldsymbol{W}\\ \tag{8}\label{eq:dfdz}\end{equation}\)
\[\begin{equation} \frac{\partial{\boldsymbol{f}}}{\partial{\boldsymbol{W}}}=\boldsymbol{z}^T\\ \tag{9}\label{eq:dfdW} \end{equation}\]Using (\ref{eq:dLdf}) & (\ref{eq:dfdW}) \(\begin{equation} \frac{\partial{L}}{\partial{\boldsymbol{W}}}=\frac{\partial{L}}{\partial{\boldsymbol{f}}}\frac{\partial{\boldsymbol{f}}}{\partial{\boldsymbol{W}}}\\ \tag{10}\label{eq:dLdW} \end{equation}\) \(\begin{equation} \frac{\partial{L}}{\partial{\boldsymbol{b}}}=sumcol(\frac{\partial{L}}{\partial{\boldsymbol{f}}})\\ \tag{11}\label{eq:dLdb} \end{equation}\) We will need (\ref{eq:dLdW}) and (\ref{eq:dLdb}) later during the parameter update step.
Context vector
Using (\ref{eq:dLdf}) & (\ref{eq:dfdz}) \(\begin{equation} \frac{\partial{L}}{\partial{\boldsymbol{z}}}=\frac{\partial{L}}{\partial{\boldsymbol{f}}}\frac{\partial{\boldsymbol{f}}}{\partial{\boldsymbol{z}}}\\ \tag{12}\label{eq:dLdz} \end{equation}\)
The transpose operation ensures that the shape of the gradient matrix matches that of $\boldsymbol{z}$.
Attention
The below code computes the gradients for the attention layer. For more detailed explanations, please refer to this post.
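A sketch of these gradients under the conventions used in the earlier sketches (row-wise softmax, $z = v\alpha^T$); factoring them into a helper called AttentionGradients is our choice, not necessarily the post's:

```julia
# x is the attention-layer input (embeddings with the positional row appended),
# dLdz the upstream gradient from equation (12).
function AttentionGradients(x, q, k, v, α, dLdz)
    d_k = size(k, 1)
    dLdv = dLdz * α                                # from z = v * α'
    dLdα = dLdz' * v                               # n × n
    # softmax backward, applied row-wise
    dLde = α .* (dLdα .- sum(dLdα .* α; dims=2))
    dLdq = (k * dLde') ./ sqrt(d_k)                # from e = qᵀk / √d_k
    dLdk = (q * dLde) ./ sqrt(d_k)
    # gradients of the linear transforms q = Qx + Qb, k = Kx + Kb, v = Vx + Vb
    dLdQ, dLdQb = dLdq * x', sum(dLdq; dims=2)
    dLdK, dLdKb = dLdk * x', sum(dLdk; dims=2)
    dLdV, dLdVb = dLdv * x', sum(dLdv; dims=2)
    return dLdQ, dLdQb, dLdK, dLdKb, dLdV, dLdVb
end
```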
Update step
The first block of code below initializes the new parameters to the current ones. Then, using the gradients computed previously for attention and equations (\ref{eq:dLdW}) and (\ref{eq:dLdb}), the new parameters $Q_{new},Qb_{new},K_{new},Kb_{new},V_{new},Vb_{new},W_{new},b_{new}$ are computed with an SGD update.
Backprop code
The full backprop function is as below.
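A sketch of the full backward pass and SGD update, wiring together equations (7) and (10) to (12) with the AttentionGradients helper above; the learning rate lr and the exact argument layout are assumptions:

```julia
function backprop(x, y, Q, Qb, K, Kb, V, Vb, W, b, q, k, v, α, z, p; lr=0.01)
    n = size(z, 2)
    dLdp = p .- y                                  # softmax + cross entropy
    dLdf = (dLdp ./ n) * ones(1, n)                # pooling backward, equation (7)
    dLdW = dLdf * z'                               # equation (10)
    dLdb = sum(dLdf; dims=2)                       # equation (11)
    dLdz = W' * dLdf                               # equation (12)
    dLdQ, dLdQb, dLdK, dLdKb, dLdV, dLdVb = AttentionGradients(x, q, k, v, α, dLdz)

    # SGD parameter updates
    Q_new, Qb_new = Q .- lr .* dLdQ, Qb .- lr .* dLdQb
    K_new, Kb_new = K .- lr .* dLdK, Kb .- lr .* dLdKb
    V_new, Vb_new = V .- lr .* dLdV, Vb .- lr .* dLdVb
    W_new, b_new  = W .- lr .* dLdW, b .- lr .* dLdb   # set W_new, b_new = W, b instead to freeze the feedforward layer
    return Q_new, Qb_new, K_new, Kb_new, V_new, Vb_new, W_new, b_new
end
```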
Sentiment data
For our crude sentiment example, let us create a small sentiment dataset with the first column containing the sentiment and the second column containing the text. The contents of “small.csv” are
We next import the CSV and DataFrames packages so we can load the csv file.
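A sketch of the loading step (the column layout follows the description above; no column names are assumed):

```julia
using CSV, DataFrames

# First column: sentiment label, second column: sentence text.
df = CSV.read("small.csv", DataFrame)
```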
To handle the case where a word in a sentence is not in our embeddings dictionary, for simplicity we will just drop it using the following remove_nid function.
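A sketch of remove_nid, assuming the word_index dictionary built earlier:

```julia
# Drop tokens that do not appear in the embeddings vocabulary.
remove_nid(words) = filter(w -> haskey(word_index, w), words)
```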
Training
For training, we have the below train function which
- calls on forwardprop for the forward propagate step
- computes the Cross Entropy Loss with CrossEntropyLoss(p,y)
- backpropagates and returns the updated parameters with backprop(x,y,train_params…,q,k,v,α,z,p)
Note that we absorb the training parameters with the Julia splat (…) notation to keep the function signature readable.
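A sketch of a per-sentence train step consistent with the list above; the exact signature and the choice to also return the loss are assumptions:

```julia
# train_params is assumed to be the tuple (Q, Qb, K, Kb, V, Vb, W, b).
function train(x, y, train_params...)
    xpos, q, k, v, α, z, p = Forwardprop(x, train_params...)  # forward pass
    loss = CrossEntropyLoss(p, y)                             # cross entropy loss
    new_params = backprop(xpos, y, train_params..., q, k, v, α, z, p)
    return new_params, loss
end
```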
The below code
- initializes small and random training parameters
- creates a dictionary that maps sentiments (positive, negative, neutral) to one-hot vectors
- processes our input sentences
- trains our minimal architecture
- returns loss
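A sketch of that driver code; the 50-dimensional embeddings, the initialization scale, the epoch count, the string labels in the sentiment column and the run_training wrapper are all assumptions:

```julia
d = 50 + 1                                           # embedding dim + positional row

# Small random initialization of the trainable parameters.
Q, K, V = 0.01 .* randn(d, d), 0.01 .* randn(d, d), 0.01 .* randn(d, d)
Qb, Kb, Vb = zeros(d), zeros(d), zeros(d)
W, b = 0.01 .* randn(3, d), zeros(3)
train_params = (Q, Qb, K, Kb, V, Vb, W, b)

# Map sentiments to one-hot vectors.
onehot = Dict("negative" => [1.0, 0.0, 0.0],
              "neutral"  => [0.0, 1.0, 0.0],
              "positive" => [0.0, 0.0, 1.0])

function run_training(df, train_params; epochs=200)
    losses = Float64[]
    for epoch in 1:epochs
        epoch_loss = 0.0
        for row in eachrow(df)
            words = remove_nid(word_tokeniser(row[2]))   # second column: text
            isempty(words) && continue                   # skip sentences with no known words
            x = hcat(get_embeddings.(words)...)          # d × n embedding matrix
            train_params, loss = train(x, onehot[row[1]], train_params...)
            epoch_loss += loss
        end
        push!(losses, epoch_loss)
    end
    return train_params, losses
end

train_params, losses = run_training(df, train_params)
```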
To see the queries, keys, values being trained, we will deliberately not update the feedforward network parameters by commenting out the update steps for $\boldsymbol{W,b}$ in the backprop function.
Results
We are interested in the resulting attention weights after training. To help us visualize them, we first import the Julia Plots package.
We then write a function evaluate_model which takes in a sentence sen, removes words not in our dictionary, transforms each word into its embedding and then forward propagates them to obtain the attention weights. With the attention weights, we can now plot a heatmap with the Plots package.
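A sketch of evaluate_model; the heatmap styling is an assumption:

```julia
using Plots

function evaluate_model(sen, train_params...)
    words = remove_nid(word_tokeniser(sen))
    x = hcat(get_embeddings.(words)...)
    _, _, _, _, α, _, _ = Forwardprop(x, train_params...)   # keep only the attention weights
    heatmap(words, words, α; c=:blues, title="Attention weights")
end
```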
When we run
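the function on a hypothetical sentence (the actual test sentences from the post are not reproduced here), for example:

```julia
# Hypothetical example sentence, purely for illustration.
evaluate_model("the movie was not bad at all", train_params...)
```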
We can see from the below heatmap that the attention weights are “focusing” in on certain words.
Full Code
Putting it all together
And that is attention from scratch!
References
- Michael Nielsen. Neural Networks and Deep Learning, Chapter 1. http://neuralnetworksanddeeplearning.com/chap1.html
- CS231n. Backpropagation, Intuitions. https://cs231n.github.io/optimization-2/
- Stanford CS224N 2023, Lecture 8: Transformers. https://web.stanford.edu/class/archive/cs/cs224n/cs224n.1234/slides/cs224n-2023-lecture08-transformers.pdf
- Backpropagation. https://kcin96.github.io/notes/ml/2024/01/18/backpropagation.html
- Why mean squared loss works poorly with softmax. https://kcin96.github.io/notes/ml/2023/12/29/why-mean-squared-loss-works-poorly-with-softmax.html
- Ashish Vaswani et al. 2017. Attention Is All You Need.
- Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. GloVe: Global Vectors for Word Representation.