Attention from Scratch in Julia
- Overview
- Embeddings
- Attention
- Minimal Architecture
- Forward pass
- Backwards pass
- Sentiment data
- Training
- Results
- Full Code
- References
Overview
This post explores the attention mechanism by building it from scratch and applying it to a rudimentary sentiment analysis example.
Embeddings
We start with a sentence. A sentence with $n$ words is essentially a sequence of words $w_1, w_2, \ldots, w_n$. To process them numerically, we need to map each word into a vector representation, i.e. word $w_i \to$ vector $x_i$. To do so, we can use pretrained word embeddings such as GloVe. This is represented by
\(\begin{equation}E(w_i) = x_i \tag{1}\label{eq:embeddings}\end{equation}\) where $E$ is an embeddings transformation.
First import an embeddings package and load in GloVe embeddings.
using Embeddings
const embtable = load_embeddings(GloVe{:en},1,max_vocab_size=10000)
const get_word_index = Dict(word=>ii for (ii,word) in enumerate(embtable.vocab))
Next we write 2 helper functions - get_embeddings, which looks up a word in embtable and returns its embedding (\ref{eq:embeddings}), and word_tokeniser, which splits a sentence into a vector of words.
#Returns embeddings for word
function get_embeddings(word)
return embtable.embeddings[:,get_word_index[word]]
end
#Splits sentence into a vector of words
function word_tokeniser(sentence)
return split(sentence," ")
end
We can try these out in the Julia REPL.
julia> get_embeddings("red")
50-element Vector{Float32}:
-0.12878
0.8798
-0.60694
0.12934
0.5868
-0.038246
-1.0408
-0.52881
-0.29563
-0.72567
0.21189
0.17112
0.19173
⋮
0.050825
-0.20362
0.13695
0.26686
-0.19461
-0.75482
1.0303
-0.057467
-0.32327
-0.7712
-0.16764
-0.73835
julia> word_tokeniser("This is a sentence")
4-element Vector{SubString{String}}:
"This"
"is"
"a"
"sentence"
Attention
We have 3 learnable weight parameters - queries $\boldsymbol{Q}$, keys $\boldsymbol{K}$ and values $\boldsymbol{V}$, with dimensions $(d_q+1)\times(d_q+1)$, $(d_k+1)\times(d_k+1)$ and $(d_v+1)\times(d_v+1)$ respectively. Let $\boldsymbol{x}$ be a $(d_q+1)\times n$ dimensional matrix. Next, apply a linear transformation of the embeddings $\boldsymbol{x}$ with all 3 weight matrices, giving us $\boldsymbol{q},\boldsymbol{k},\boldsymbol{v}$. \(\begin{equation} \boldsymbol{Q}\boldsymbol{x} = \boldsymbol{q}\\ \boldsymbol{K}\boldsymbol{x} = \boldsymbol{k}\\ \boldsymbol{V}\boldsymbol{x} = \boldsymbol{v}\\ \tag{2}\label{eq:qkv} \end{equation}\)
For simplicity's sake, let us assume that every word embedding of a sentence serves as a key, that keys = values, and that each word embedding is also a query. This implies that $d_q=d_k$.
\[Cosine\ Similarity\ \boldsymbol{e} = \frac{\boldsymbol{q}^T\boldsymbol{k}}{\sqrt{d_k}}\tag{3}\label{eq:cosine similarity}\]Note that the cosine similarity has been scaled by $\frac{1}{\sqrt{d_k}}$. To gain an intuition as to why, compare the norm between a 2-element vector and a 3-element vector, e.g. \(\begin{aligned} &\sqrt{\begin{bmatrix}2 \\ 2\end{bmatrix} \bullet \begin{bmatrix}2 \\ 2\end{bmatrix}}=\sqrt{2^2+2^2} &=2\sqrt{2} \end{aligned}\) \(\begin{aligned} &\sqrt{\begin{bmatrix}2 \\ 2 \\ 2\end{bmatrix} \bullet \begin{bmatrix}2 \\2 \\ 2\end{bmatrix}}=\sqrt{2^2+2^2+2^2} &=2\sqrt{3} \end{aligned}\) It is clear that as the dimension $d$ grows, the norm scales by $\sqrt{d}$.
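As a quick numerical illustration (a hedged aside, not part of the original derivation), we can check in Julia that the typical magnitude of a dot product between random vectors grows roughly with $\sqrt{d}$:
using LinearAlgebra, Statistics
# Dot products of independent standard normal vectors have standard deviation sqrt(d),
# which is exactly what the 1/sqrt(d_k) factor compensates for.
for d in (10, 100, 1000)
    dots = [dot(randn(d), randn(d)) for _ in 1:10_000]
    println("d = ", d, ": std of dot products ≈ ", round(std(dots), digits=2),
            " vs sqrt(d) ≈ ", round(sqrt(d), digits=2))
end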
\[Attention\ Weights \ \boldsymbol{\alpha} = softmax(\frac{\boldsymbol{q}^T\boldsymbol{k}}{\sqrt{d_k}})\tag{4}\label{eq:attention weights}\] \[Attention\ \boldsymbol{z}= softmax(\frac{\boldsymbol{q}^T\boldsymbol{k}}{\sqrt{d_k}})\boldsymbol{v} \tag{5}\label{eq:attention}\]Let us define 2 helper functions LinearTransform and softmax as below.
function LinearTransform(x,W,b)
return W*x.+b
end
function softmax(x)
x = x .- maximum(x)
return exp.(x) ./ sum(exp.(x))
end
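As a small sanity check (an assumed example, not from the original post), the softmax output behaves like a probability distribution and is unchanged by the maximum-subtraction trick:
v = [1.0, 2.0, 3.0]
println(softmax(v))         # entries ≈ [0.09, 0.245, 0.665]
println(sum(softmax(v)))    # ≈ 1.0
println(softmax(v .+ 100))  # same result: subtracting the maximum only adds numerical stability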
Then define a function called Attention as below.
# Attention block
function Attention(x,Q,Qb,K,Kb,V,Vb)
# queries
q = LinearTransform(x,Q,Qb)
# keys
k = LinearTransform(x,K,Kb)
# values
v = LinearTransform(x,V,Vb)
# Attention Weights
α = AttentionWeights(x,q,k,v)
# context vectors
z = v * α'
return q,k,v,α,z
end
The first line creates a function called Attention, which takes in the word embeddings $\boldsymbol{x}$ and the trainable parameters $\boldsymbol{Q,Qb,K,Kb,V,Vb}$ as arguments. Note that I have chosen to separate the Query $\boldsymbol{Q}$, Key $\boldsymbol{K}$, Value $\boldsymbol{V}$ weight matrices from their biases $\boldsymbol{Qb,Kb,Vb}$ for clarity. So $\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}$ are each $(d_q+1)\times(d_q+1)$ dimensional and $\boldsymbol{x}$ is $(d_q+1)\times n$ dimensional.
# queries
q = LinearTransform(x,Q,Qb)
# keys
k = LinearTransform(x,K,Kb)
# values
v = LinearTransform(x,V,Vb)
The next 3 lines apply a linear transformation to the word embeddings as per equation (\ref{eq:qkv}).
# Attention Weights
α = AttentionWeights(x,q,k,v)
This line calls the AttentionWeights function, which returns the attention weight matrix $\alpha$ as per equation (\ref{eq:attention weights}).
# context vectors
z = v * α'
The last line implements equation (\ref{eq:attention}) and returns the context vectors.
# Return Attention Weights
function AttentionWeights(x,q,k,v)
# compute similarity between queries and keys (with scaling)
e = q'*k/sqrt(length(q))
# initialize attention weight matrix α with zeroes
α = zeros(size(e))
# normalize each similarity row with softmax
for row in 1:size(e)[1]
α[row,:] = softmax(e[row,:])
end
return α
end
The AttentionWeights function first computes the unnormalized cosine similarity matrix $e$ between queries and keys as in (\ref{eq:cosine similarity}), scaled by the square root of the key dimension. Then, we apply the softmax function along each row of the resulting matrix.
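Before moving on, a quick shape check (a toy, assumed example with made-up dimensions) confirms that $\alpha$ is an $n\times n$ matrix whose rows sum to 1 and that the context vectors $z$ have the same shape as $v$:
# Toy "sentence" of 4 words with 5-dimensional embeddings (assumed values)
d_toy, n_toy = 5, 4
x_toy = randn(d_toy, n_toy)
Q_toy, K_toy, V_toy = randn(d_toy, d_toy)/10, randn(d_toy, d_toy)/10, randn(d_toy, d_toy)/10
Qb_toy, Kb_toy, Vb_toy = zeros(d_toy, 1), zeros(d_toy, 1), zeros(d_toy, 1)
q_toy, k_toy, v_toy, α_toy, z_toy = Attention(x_toy, Q_toy, Qb_toy, K_toy, Kb_toy, V_toy, Vb_toy)
println(size(α_toy))          # (4, 4): one attention weight per (query, key) pair
println(sum(α_toy, dims=2))   # every row sums to ≈ 1 after the softmax
println(size(z_toy))          # (5, 4): same shape as v_toy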
Now we have our attention building block! For a more detailed explanation of the forward and backpropagation steps for attention, please refer to this post.
Minimal Architecture
To see attention in action, let us consider a minimal functional architecture with attention at the heart of it.
The architecture consists of 4 intermediate layers - an Attention layer, a Feed forward layer, a Pooling layer and a Softmax layer. For the input, we pass in a concatenation of the word embeddings and the positional embeddings. At the final output layer, we obtain the classification probabilities. For this example, let us assume that we are doing sentiment analysis for a sentence. The sentiment labels are one-hot vectors - positive $[0,0,1]$, negative $[1,0,0]$, neutral $[0,1,0]$. So when we pass an input sentence through the above architecture, we get the probabilities of whether our input sentence is positive, negative or neutral.
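To keep the shapes straight (assuming the 50-dimensional GloVe embeddings used here plus one positional row, so each input column is 51-dimensional), the data flows through the layers as
\[\boldsymbol{x}\ (51\times n)\xrightarrow{\text{Attention}}\boldsymbol{z}\ (51\times n)\xrightarrow{\text{Feed forward}}\boldsymbol{f}\ (3\times n)\xrightarrow{\text{Pooling}}(3\times 1)\xrightarrow{\text{Softmax}}\boldsymbol{p}\ (3\times 1)\]
for a sentence of $n$ words.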
Forward pass
Let us now create a function forwardprop for the forward pass through the above architecture. The function takes in a word embedding matrix $\boldsymbol{x}$, the Attention parameters $\boldsymbol{Q,Q_b,K,K_b,V,V_b}$ and the Feedforward parameters $\boldsymbol{W,b}$ as arguments.
function forwardprop(x,Q,Qb,K,Kb,V,Vb,W,b)
# Reshape from 1d to 2d
if ndims(x)==1
x = reshape(x,(:,1))
end
x = vcat(x,Vector(range(0,size(x)[2]-1))')
The if condition handles a sentence with just a single word by reshaping the 1D word embedding into a 2D matrix. The vcat concatenates the word embeddings $x$ with their positional embeddings (each word's index position in the sentence).
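For instance (an assumed toy example), for a 3-word sentence with 2-dimensional embeddings the concatenation appends the positions 0, 1, 2 as an extra row:
x_demo = [0.1 0.2 0.3;
          0.4 0.5 0.6]
vcat(x_demo, Vector(range(0, size(x_demo)[2]-1))')
# 3×3 Matrix{Float64}:
#  0.1  0.2  0.3
#  0.4  0.5  0.6
#  0.0  1.0  2.0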
# Return attention values
q,k,v,α,z = Attention(x,Q,Qb,K,Kb,V,Vb)
We next pass in the arguments to the Attention function as defined in Attention, which returns the queries, keys, values, attention weights and context vectors.
# Feed Forward layer
f = FeedForward(z,W,b) #shape: [3xn]
Next, we pass the context vectors $z$, currently a $(d_v+1)\times n$ matrix, through a feedforward layer. The FeedForward function is shown below.
function FeedForward(x,W,b)
return LinearTransform(x,W,b)
end
The feedforward layer transforms the context vectors $z$ into a $3\times n$ matrix.
# Average pooling.
p = sum(f,dims=2)/size(f)[2] #shape: [3x1]
# Softmax layer to get probabilities
p = softmax(p)
end
Now, before passing it through the softmax function, we need to reduce the $3\times n$ matrix to a $3\times 1$ matrix. A simple, naive way is to just average over the $n$ columns. Hence, we pass $f$ through a pooling layer to obtain a $3\times 1$ vector before passing it through the softmax layer. The final output gives the probabilities of the sentiments.
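As a toy illustration of the pooling step (assumed numbers), averaging the columns of a $3\times 2$ matrix yields a single $3\times 1$ column:
f_demo = [1.0 3.0;
          0.0 2.0;
          2.0 4.0]
sum(f_demo, dims=2)/size(f_demo)[2]
# 3×1 Matrix{Float64}:
#  2.0
#  1.0
#  3.0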
Putting it all together,
function FeedForward(x,W,b)
return LinearTransform(x,W,b)
end
#Forward propagate
function forwardprop(x,Q,Qb,K,Kb,V,Vb,W,b)
# Reshape from 1d to 2d
if ndims(x)==1
x = reshape(x,(:,1))
end
x = vcat(x,Vector(range(0,size(x)[2]-1))')
# Return attention values
q,k,v,α,z = Attention(x,Q,Qb,K,Kb,V,Vb)
# Feed Forward layer
f = FeedForward(z,W,b) #shape: [3xn]
# Average pooling.
p = sum(f,dims=2)/size(f)[2] #shape: [3x1]
# Softmax layer to get probabilities
p = softmax(p)
return x,q,k,v,α,z,p
end
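As a hedged end-to-end check with made-up parameters (not the trained ones used later), we can confirm that forwardprop turns a fake two-word sentence of 50-dimensional embeddings into a $3\times 1$ vector of probabilities:
using Random
rng_demo = MersenneTwister(1)
x_demo  = randn(rng_demo, 50, 2)                 # two fake 50-dimensional "word embeddings"
Q_demo  = randn(rng_demo, 51, 51)/100; Qb_demo = zeros(51,1)
K_demo  = randn(rng_demo, 51, 51)/100; Kb_demo = zeros(51,1)
V_demo  = randn(rng_demo, 51, 51)/100; Vb_demo = zeros(51,1)
W_demo  = randn(rng_demo, 3, 51)/100;  b_demo  = zeros(3,1)
p_demo  = forwardprop(x_demo, Q_demo, Qb_demo, K_demo, Kb_demo, V_demo, Vb_demo, W_demo, b_demo)[end]
println(size(p_demo))   # (3, 1)
println(sum(p_demo))    # ≈ 1.0: a valid probability distribution over the three sentiments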
Backwards pass
This section dives into the details of computing gradients required for backpropagation.
Cross Entropy
As we have a softmax function at the last layer, a good loss function to use is the cross entropy loss, defined below. Note that mean squared error can work poorly with softmax, as explained here.
function CrossEntropyLoss(z,x)
return -sum(x.*log.(z))
end
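For example (assumed numbers), if the model assigns probability 0.7 to the correct class, the loss is simply $-\log(0.7)$:
p_demo = [0.7, 0.2, 0.1]           # predicted probabilities
y_demo = [1, 0, 0]                 # one-hot label ("negative")
CrossEntropyLoss(p_demo, y_demo)   # -log(0.7) ≈ 0.357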
Working our way backwards, we first compute the derivative of the loss with respect to the softmax layer, which is simply the prediction $\boldsymbol{p}$ minus the training label $\boldsymbol{y}$.
\[\begin{aligned} \frac{\partial{L}}{\partial{\boldsymbol{p}}}=\boldsymbol{p}-\boldsymbol{y}\\ \end{aligned}\]
function backprop(x,y,
Q,Qb,K,Kb,V,Vb,W,b,
q,k,v,α,z,p,
η=.001)
# Softmax gradient ∂L/∂σ
∂L_∂p = p-y #shape: [3x1]
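As a hedged aside (not part of the original derivation), a quick finite-difference check confirms that the gradient of the cross entropy composed with softmax, taken with respect to the pre-softmax input, is indeed the prediction minus the label:
a = randn(3, 1)                        # pre-softmax scores (assumed values)
y_demo = [1.0, 0.0, 0.0]               # one-hot label
ϵ = 1e-6
num_grad = zeros(3)
for i in 1:3
    a_plus = copy(a);  a_plus[i]  += ϵ
    a_minus = copy(a); a_minus[i] -= ϵ
    num_grad[i] = (CrossEntropyLoss(softmax(a_plus), y_demo) -
                   CrossEntropyLoss(softmax(a_minus), y_demo)) / (2ϵ)
end
println(num_grad ≈ vec(softmax(a) .- y_demo))   # should print true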
Pooling
Recall that for the pooling layer, we averaged across all columns as below.
\[\begin{aligned} \boldsymbol{f}&=\begin{bmatrix}\boldsymbol{f_1} \ \boldsymbol{f_2} \ \boldsymbol{f_3} \ \ldots \boldsymbol{f_n} \end{bmatrix}\\ \boldsymbol{p}&=\frac{1}{n}\sum_{i=1}^{n}\boldsymbol{f_i}\\ &=\frac{1}{n}(\boldsymbol{f_1} + \boldsymbol{f_2} + \boldsymbol{f_3} + \ldots + \boldsymbol{f_n}) \\ \end{aligned}\tag{6}\label{eq:pooling}\]The derivative of (\ref{eq:pooling}) gives us the gradients
\[\begin{aligned} \frac{\partial{\boldsymbol{p}}}{\partial{\boldsymbol{f}}} &=\frac{1}{n}\begin{bmatrix}1 \ 1 \ \ldots 1\\ \end{bmatrix} \end{aligned}\] \[\begin{equation} \frac{\partial{L}}{\partial{\boldsymbol{f}}}=\frac{\partial{L}}{\partial{\boldsymbol{p}}}\frac{\partial{\boldsymbol{p}}}{\partial{\boldsymbol{f}}}\\\tag{7}\label{eq:dLdf} \end{equation}\]
# Average pooling gradient ∂L/∂f
∂p_∂f = (1 ./size(z)[2] .*ones(1,size(z)[2]))
∂L_∂f = ∂L_∂p*∂p_∂f #shape: [3xn]
Feedforward
The local gradients of $\boldsymbol{f}$ are \(\begin{equation}\frac{\partial{\boldsymbol{f}}}{\partial{\boldsymbol{z}}}=\boldsymbol{W}\\ \tag{8}\label{eq:dfdz}\end{equation}\)
\[\begin{equation} \frac{\partial{\boldsymbol{f}}}{\partial{\boldsymbol{W}}}=\boldsymbol{z}^T\\ \tag{9}\label{eq:dfdW} \end{equation}\]Using (\ref{eq:dLdf}) & (\ref{eq:dfdW}) \(\begin{equation} \frac{\partial{L}}{\partial{\boldsymbol{W}}}=\frac{\partial{L}}{\partial{\boldsymbol{f}}}\frac{\partial{\boldsymbol{f}}}{\partial{\boldsymbol{W}}}\\ \tag{10}\label{eq:dLdW} \end{equation}\) \(\begin{equation} \frac{\partial{L}}{\partial{\boldsymbol{b}}}=sumcol(\frac{\partial{L}}{\partial{\boldsymbol{f}}})\\ \tag{11}\label{eq:dLdb} \end{equation}\) We will need (\ref{eq:dLdW}) and (\ref{eq:dLdb}) later during the parameter update step.
# NN local gradients ∂f/∂z, ∂f/∂W
∂f_∂z = W #shape: [3xd]
∂f_∂W = z' #shape: [nxd]
# NN gradients ∂L/∂W and ∂L/∂b
∂L_∂W = ∂L_∂f*∂f_∂W #shape: [3xd]
∂L_∂b = sum(∂L_∂f,dims=2) #shape: [3x1]
Context vector
Using (\ref{eq:dLdf}) & (\ref{eq:dfdz}) \(\begin{equation} \frac{\partial{L}}{\partial{\boldsymbol{z}}}=\frac{\partial{L}}{\partial{\boldsymbol{f}}}\frac{\partial{\boldsymbol{f}}}{\partial{\boldsymbol{z}}}\\ \tag{12}\label{eq:dLdz} \end{equation}\)
# Context vector gradients
∂L_∂z = (∂L_∂f'*∂f_∂z)' #shape: [dxn]
The transpose operations ensures that the shape of the gradient matrix matches that of $\boldsymbol{z}$.
Attention
The below code computes the gradients for the attention layer. For more detailed explanations, please refer to this post.
# Attention gradients
# Local value gradients ∂z/∂v, ∂v/∂V
∂z_∂v = α #shape: [nxn]
∂v_∂V = x' #shape: [nxd]
# Local attention weight gradients ∂z/∂α
∂z_∂α = v #shape: [dxn]
# Initialize ∂α/∂e to zeroes
∂α_∂e = zeros(size(α)[1],size(α)[2]) #shape: [nxn]
# Derivative of softmax
for k in 1:size(α)[1]
for j in 1:size(α)[2]
if j == k
∂α_∂e[j,k] = α[j]*(1-α[j])
else
∂α_∂e[j,k] = -α[k]*α[j]
end
end
end
# Local query, key gradients ∂e_∂q, ∂e_∂k
∂e_∂q, ∂e_∂k = k', q' #shape: [nxd],[nxd]
∂q_∂Q, ∂k_∂K = x', x' #shape: [nxd],[nxd]
# Softmax gradients
∂L_∂α = ∂L_∂z'*∂z_∂α #shape: [nxn]
# Similarity score gradients
∂L_∂e = ∂L_∂α*∂α_∂e #shape: [nxn]
# query gradients
∂L_∂q = ∂L_∂e*∂e_∂q #shape: [nxd]
# key gradients
∂L_∂k = ∂L_∂e'*∂e_∂k #shape: [nxd]
# values gradients
∂L_∂v = ∂L_∂z*∂z_∂v #shape: [dxn]
# Q,K,V parameter gradients
∂L_∂Q = ∂L_∂q'*∂q_∂Q #shape: [dxd]
∂L_∂K = ∂L_∂k'*∂k_∂K #shape: [dxd]
∂L_∂V = ∂L_∂v*∂v_∂V #shape: [dxd]
∂L_∂Qb = sum(∂L_∂q',dims=2) #shape: [dx1]
∂L_∂Kb = sum(∂L_∂k',dims=2) #shape: [dx1]
∂L_∂Vb = sum(∂L_∂v,dims=2) #shape: [dx1]
Update step
The first block of code below initializes new parameters to the current ones. Then, using the gradients computed previously for attention and equations (\ref{eq:dLdW}) and (\ref{eq:dLdb}), the new $Q_{new},Qb_{new},K_{new},Kb_{new},V_{new},Vb_{new},W_{new},b_{new}$ can be updated using SGD.
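Each parameter follows the standard SGD update with learning rate $\eta$ (0.001 by default in backprop):
\[\theta_{new} = \theta - \eta\frac{\partial{L}}{\partial{\theta}}\]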
# Update Attention parameters
# Initialize new parameter matrices with current parameters
Q_new = Q
Qb_new = Qb
K_new = K
Kb_new = Kb
V_new = V
Vb_new = Vb
W_new = W
b_new = b
# Update all trainable parameters with SGD
Q_new = Q_new .- η * ∂L_∂Q
Qb_new = Qb_new .- η * ∂L_∂Qb
K_new = K_new .- η * ∂L_∂K
Kb_new = Kb_new .- η * ∂L_∂Kb
V_new = V_new .- η * ∂L_∂V
Vb_new = Vb_new .- η * ∂L_∂Vb
W_new = W_new .- η * ∂L_∂W
b_new = b_new .- η * ∂L_∂b
Backprop code
The full backprop function is shown below.
# Backpropagate
function backprop(x,y,
Q,Qb,K,Kb,V,Vb,W,b,
q,k,v,α,z,p,
η=.001)
# Softmax gradient ∂L/∂σ
∂L_∂p = p-y #shape: [3x1]
# Average pooling gradient ∂L/∂f
∂p_∂f = (1 ./size(z)[2] .*ones(1,size(z)[2]))
∂L_∂f = ∂L_∂p*∂p_∂f #shape: [3xn]
# NN local gradients ∂f/∂z, ∂f/∂W
∂f_∂z = W #shape: [3xd]
∂f_∂W = z' #shape: [nxd]
# NN gradients ∂L/∂W and ∂L/∂b
∂L_∂W = ∂L_∂f*∂f_∂W #shape: [3xd]
∂L_∂b = sum(∂L_∂f,dims=2) #shape: [3x1]
# Context vector gradients
∂L_∂z = (∂L_∂f'*∂f_∂z)' #shape: [dxn]
# Attention gradients
# Local value gradients ∂z/∂v, ∂v/∂V
∂z_∂v = α #shape: [nxn]
∂v_∂V = x' #shape: [nxd]
# Local attention weight gradients ∂z/∂α
∂z_∂α = v #shape: [dxn]
# Initialize ∂α/∂e to zeroes
∂α_∂e = zeros(size(α)[1],size(α)[2]) #shape: [nxn]
# Derivative of softmax
for k in 1:size(α)[1]
for j in 1:size(α)[2]
if j == k
∂α_∂e[j,k] = α[j]*(1-α[j])
else
∂α_∂e[j,k] = -α[k]*α[j]
end
end
end
# Local query, key gradients ∂e_∂q, ∂e_∂k
∂e_∂q, ∂e_∂k = k', q' #shape: [nxd],[nxd]
∂q_∂Q, ∂k_∂K = x', x' #shape: [nxd],[nxd]
# Softmax gradients
∂L_∂α = ∂L_∂z'*∂z_∂α #shape: [nxn]
# Similarity score gradients
∂L_∂e = ∂L_∂α*∂α_∂e #shape: [nxn]
# query gradients
∂L_∂q = ∂L_∂e*∂e_∂q #shape: [nxd]
# key gradients
∂L_∂k = ∂L_∂e'*∂e_∂k #shape: [nxd]
# values gradients
∂L_∂v = ∂L_∂z*∂z_∂v #shape: [dxn]
# Q,K,V parameter gradients
∂L_∂Q = ∂L_∂q'*∂q_∂Q #shape: [dxd]
∂L_∂K = ∂L_∂k'*∂k_∂K #shape: [dxd]
∂L_∂V = ∂L_∂v*∂v_∂V #shape: [dxd]
∂L_∂Qb = sum(∂L_∂q',dims=2) #shape: [dx1]
∂L_∂Kb = sum(∂L_∂k',dims=2) #shape: [dx1]
∂L_∂Vb = sum(∂L_∂v,dims=2) #shape: [dx1]
# Update Attention parameters
# Initialize new parameter matrices with current parameters
Q_new = Q
Qb_new = Qb
K_new = K
Kb_new = Kb
V_new = V
Vb_new = Vb
W_new = W
b_new = b
# Update all trainable parameters with SGD
Q_new = Q_new .- η * ∂L_∂Q
Qb_new = Qb_new .- η * ∂L_∂Qb
K_new = K_new .- η * ∂L_∂K
Kb_new = Kb_new .- η * ∂L_∂Kb
V_new = V_new .- η * ∂L_∂V
Vb_new = Vb_new .- η * ∂L_∂Vb
W_new = W_new .- η * ∂L_∂W
b_new = b_new .- η * ∂L_∂b
return Q_new,Qb_new,K_new,Kb_new,V_new,Vb_new,W_new,b_new
end
Sentiment data
For our crude sentiment example, let us create a small sentiment dataset with the first column containing the sentiment and the second column containing the text. The contents of "small.csv" are
sentiments,cleaned_review
positive,i love this speaker
negative,this mouse was waste of money it broke after using it
positive,great product
neutral,it meets my need
negative,the volume on this is lower than my flip big disappointment
positive,the best
negative,little disappointed and it only comes with charging cord not wall plug
positive,super
neutral,caught this on sale and for the price the quality is unbeatable
neutral,relax on my home
negative,the battery died in week doesn charge
negative,product is sub par so many things bad is that enough for me to submit it terrible
negative,had an older jbl portable speaker the charge is way more expensive and larger than my old one but not any louder was disappointed in that
neutral,easy to carry
neutral,the sound that comes out of this thing is incredible
positive,its strong clear and great
positive,loved it awesome
positive,i am happy with this product
negative,i am not happy with this
negative,it stopped working my is very disappointed
positive,i love the sound quality this is great product for the price
negative,bad audio input
negative,pretty decent but not the best for box
negative,does not work for ps bad quality not recommend
neutral,the clear sound
negative,my son only used this gaming headset for few months and the mic already quit working very disappointed
negative,the usb plug in is not long enough to connect to the playstation the cord is so long but splits into two six inch cords one goes into the playstation the other goes into the controller what the hell are you supposed to do
negative,they are uncomfortable and seem really fragile
positive,son loved them
neutral,thanks
negative,very frustrating as they both broke
negative,the color and shape are very nice but the mouse and the packaging arrived very dirty
negative,product lags a lot
positive,this is great service
positive,good quality product
negative,very frustrating bad quality
positive,we loved the soft texture
neutral,no sound
positive,good price
We next import the CSV and DataFrames packages to allow us to load in the csv file.
using CSV, DataFrames
tb = CSV.read("small.csv",DataFrame)
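A quick look at the loaded table (the exact display depends on the CSV contents above):
nrow(tb)        # one row per labelled sentence
first(tb, 3)    # peek at the first few (sentiment, review) pairs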
To handle the case where a word in a sentence is not in our embeddings dictionary, we will, for simplicity, just drop it using the following remove_nid function.
# Removes words that are not in dictionary
function remove_nid(sentence)
sen = []
if !ismissing(sentence)
for i in word_tokeniser(sentence)
try get_embeddings(i)
push!(sen,i)
catch e
end
end
end
return sen
end
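For example (a hedged illustration, since the exact result depends on which words made it into the 10,000-word GloVe vocabulary loaded above), a made-up token is silently dropped:
remove_nid("i love this qwertyzzz speaker")
# likely returns Any["i", "love", "this", "speaker"] if the remaining words are in the vocabulary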
Training
For training, we have the below train function which
- calls on forwardprop for the forward propagate step
- computes the Cross Entropy Loss with CrossEntropyLoss(p,y)
- backpropagates and returns the updated parameters with backprop(x,y,train_params…,q,k,v,α,z,p)
Note that we absorb the training parameters with the Julia splat (…) notation to keep the function readable.
# Train step
function train(x,y,train_params...)
x,q,k,v,α,z,p = forwardprop(x,train_params...)
CEloss = CrossEntropyLoss(p,y)
train_params = backprop(x,y,train_params...,q,k,v,α,z,p)
return train_params...,CEloss
end
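As a hedged single-step example with freshly initialized (untrained) parameters, one call to train returns the 8 updated parameters followed by the loss for that sentence (assuming both words below are in the loaded vocabulary):
using Random
rng_demo = MersenneTwister(1)
Q0, K0, V0 = randn(rng_demo,51,51)/100, randn(rng_demo,51,51)/100, randn(rng_demo,51,51)/100
Qb0, Kb0, Vb0 = zeros(51,1), zeros(51,1), zeros(51,1)
W0, b0 = randn(rng_demo,3,51)/100, zeros(3,1)
x_em_demo = hcat(get_embeddings("great"), get_embeddings("product"))
y_demo = [0, 0, 1]   # one-hot label for "positive"
out = train(x_em_demo, y_demo, Q0, Qb0, K0, Kb0, V0, Vb0, W0, b0)
println(out[end])    # cross entropy loss for this sentence
println(length(out)) # 9: the 8 updated parameters followed by the loss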
The below code
- initializes small and random training parameters
- creates a dictionary that maps sentiments (positive, negative, neutral) to one-hot vectors
- processes our input sentences
- trains our minimal architecture
- returns loss
# main
using Random
# Random seed for reproducibility
rng = MersenneTwister(12);
# Initialize small random parameter values
Q = randn(rng, (51, 51))/100
Qb = zeros(51,1)
K = randn(rng, (51, 51))/100
Kb = zeros(51,1)
V = K
Vb = zeros(51,1)
W = randn(rng, (3, 51))/100
b = zeros(3,1)
# Sentiment dictionary that converts sentiment
# text into one-hot labels
sent_dict = Dict("positive"=>[0,0,1],"negative"=>[1,0,0],"neutral"=>[0,1,0])
#training
for epoch=1:1000
total_l = 0 #total loss
for idx in 1:nrow(tb)
x_em = []
l = 0 #current loss
sen = tb[idx,"cleaned_review"] #gets sentence
sen = remove_nid(sen) #remove words not in dictionary
if length(sen)!=0
for i in (sen)
if length(x_em) == 0
x_em = get_embeddings(i)
else
#Concatenate word embeddings along columns
x_em = hcat(x_em,get_embeddings(i))
end
end
#One hot vector sentiment
y = sent_dict[tb[idx,"sentiments"]]
#Update parameters
Q,Qb,K,Kb,V,Vb,W,b,l = train(x_em,y,Q,Qb,K,Kb,V,Vb,W,b)
end
total_l += l
end
println("Total loss:", total_l/nrow(tb))
end
To see the queries, keys, values being trained, we will deliberately not update the feedforward network parameters by commenting out the update steps for $\boldsymbol{W,b}$ in the backprop function.
W_new = W_new #.- η * ∂L_∂W
b_new = b_new #.- η * ∂L_∂b
Results
We are interested in seeing the resulting attention weights after training. To help us visualize them, we first import the Julia Plots package.
using Plots
We then write a function evaluate_model which takes in a sentence sen, removes words not in our dictionary, transforms each word into its embedding and then forward propagates them to obtain the attention weights. With the attention weights, we can now plot a heatmap with the Plots package.
# Evaluates the sentiment given a sentence as input
function evaluate_model(sen)
x_em = []
sen = remove_nid(sen)
for i in (sen)
if length(x_em) == 0
x_em = get_embeddings(i)
else
x_em = hcat(x_em,get_embeddings(i))
end
end
α = forwardprop(x_em,Q,Qb,K,Kb,V,Vb,W,b)[5]
# plot heatmap of α
heatmap(sen,sen,α,clims=(0,1),aspect_ratio=1,color=:deepsea,
title="Attention weights α",grid="off")
end
When we run
evaluate_model("very sad as they both fail")
We can see from the below heatmap that the attention weights are “focusing” in on certain words.
evaluate_model("he loved that plug with good price ")
evaluate_model("terrible quality for this price")
evaluate_model("i love this fantastic product")
evaluate_model("easy to move around")
Full Code
Putting it all together
using CSV, DataFrames
using Plots
using Embeddings
using Random
tb = CSV.read("small.csv",DataFrame)
const embtable = load_embeddings(GloVe{:en},1,max_vocab_size=10000)
const get_word_index = Dict(word=>ii for (ii,word) in enumerate(embtable.vocab))
# Returns embeddings for word
function get_embeddings(word)
return embtable.embeddings[:,get_word_index[word]]
end
# Splits sentence into a vector of words
function word_tokeniser(sentence)
return split(sentence," ")
end
# Softmax function
function softmax(x)
x = x .- maximum(x)
return exp.(x) ./ sum(exp.(x))
end
# Cross Entropy Loss
function CrossEntropyLoss(z,x)
return -sum(x.*log.(z))
end
# Linear Transformation
function LinearTransform(x,W,b)
return W*x.+b
end
# Feedforward network
function FeedForward(x,W,b)
return LinearTransform(x,W,b)
end
# Return Attention Weights
function AttentionWeights(x,q,k,v)
# compute similarity between queries and keys (with scaling)
e = q'*k/sqrt(length(q))
# initialize attention weight matrix α with zeroes
α = zeros(size(e))
# normalize each similarity row with softmax
for row in 1:size(e)[1]
α[row,:] = softmax(e[row,:])
end
return α
end
# Attention block
function Attention(x,Q,Qb,K,Kb,V,Vb)
# queries
q = LinearTransform(x,Q,Qb)
# keys
k = LinearTransform(x,K,Kb)
# values
v = LinearTransform(x,V,Vb)
# Attention Weights
α = AttentionWeights(x,q,k,v)
# context vectors
z = v * α'
return q,k,v,α,z
end
# Forward propagate
function forwardprop(x,Q,Qb,K,Kb,V,Vb,W,b)
# Reshape from 1d to 2d
if ndims(x)==1
x = reshape(x,(:,1))
end
x = vcat(x,Vector(range(0,size(x)[2]-1))')
# Return attention values
q,k,v,α,z = Attention(x,Q,Qb,K,Kb,V,Vb)
# Feed Forward layer
f = FeedForward(z,W,b) #shape: [3xn]
# Average pooling.
p = sum(f,dims=2)/size(f)[2] #shape: [3x1]
# Softmax layer to get probabilities
p = softmax(p)
return x,q,k,v,α,z,p
end
# Train step
function train(x,y,train_params...)
x,q,k,v,α,z,p = forwardprop(x,train_params...)
CEloss = CrossEntropyLoss(p,y)
train_params = backprop(x,y,train_params...,q,k,v,α,z,p)
return train_params...,CEloss
end
# Backpropagate
function backprop(x,y,
Q,Qb,K,Kb,V,Vb,W,b,
q,k,v,α,z,p,
η=.001)
# Softmax gradient ∂L/∂σ
∂L_∂p = p-y #shape: [3x1]
# Average pooling gradient ∂L/∂f
∂p_∂f = (1 ./size(z)[2] .*ones(1,size(z)[2]))
∂L_∂f = ∂L_∂p*∂p_∂f #shape: [3xn]
# NN local gradients ∂f/∂z, ∂f/∂W
∂f_∂z = W #shape: [3xd]
∂f_∂W = z' #shape: [nxd]
# NN gradients ∂L/∂W and ∂L/∂b
∂L_∂W = ∂L_∂f*∂f_∂W #shape: [3xd]
∂L_∂b = sum(∂L_∂f,dims=2) #shape: [3x1]
# Context vector gradients
∂L_∂z = (∂L_∂f'*∂f_∂z)' #shape: [dxn]
# Attention gradients
# Local value gradients ∂z/∂v, ∂v/∂V
∂z_∂v = α #shape: [nxn]
∂v_∂V = x' #shape: [nxd]
# Local attention weight gradients ∂z/∂α
∂z_∂α = v #shape: [dxn]
# Initialize ∂α/∂e to zeroes
∂α_∂e = zeros(size(α)[1],size(α)[2]) #shape: [nxn]
# Derivative of softmax
for k in 1:size(α)[1]
for j in 1:size(α)[2]
if j == k
∂α_∂e[j,k] = α[j]*(1-α[j])
else
∂α_∂e[j,k] = -α[k]*α[j]
end
end
end
# Local query, key gradients ∂e_∂q, ∂e_∂k
∂e_∂q, ∂e_∂k = k', q' #shape: [nxd],[nxd]
∂q_∂Q, ∂k_∂K = x', x' #shape: [nxd],[nxd]
# Softmax gradients
∂L_∂α = ∂L_∂z'*∂z_∂α #shape: [nxn]
# Similarity score gradients
∂L_∂e = ∂L_∂α*∂α_∂e #shape: [nxn]
# query gradients
∂L_∂q = ∂L_∂e*∂e_∂q #shape: [nxd]
# key gradients
∂L_∂k = ∂L_∂e'*∂e_∂k #shape: [nxd]
# values gradients
∂L_∂v = ∂L_∂z*∂z_∂v #shape: [dxn]
# Q,K,V parameter gradients
∂L_∂Q = ∂L_∂q'*∂q_∂Q #shape: [dxd]
∂L_∂K = ∂L_∂k'*∂k_∂K #shape: [dxd]
∂L_∂V = ∂L_∂v*∂v_∂V #shape: [dxd]
∂L_∂Qb = sum(∂L_∂q',dims=2) #shape: [dx1]
∂L_∂Kb = sum(∂L_∂k',dims=2) #shape: [dx1]
∂L_∂Vb = sum(∂L_∂v,dims=2) #shape: [dx1]
# Update Attention parameters
# Initialize new parameter matrices with current parameters
Q_new = Q
Qb_new = Qb
K_new = K
Kb_new = Kb
V_new = V
Vb_new = Vb
W_new = W
b_new = b
# Update all trainable parameters with SGD
Q_new = Q_new .- η * ∂L_∂Q
Qb_new = Qb_new .- η * ∂L_∂Qb
K_new = K_new .- η * ∂L_∂K
Kb_new = Kb_new .- η * ∂L_∂Kb
V_new = V_new .- η * ∂L_∂V
Vb_new = Vb_new .- η * ∂L_∂Vb
W_new = W_new #.- η * ∂L_∂W
b_new = b_new #.- η * ∂L_∂b
return Q_new,Qb_new,K_new,Kb_new,V_new,Vb_new,W_new,b_new
end
# Removes words that are not in dictionary
function remove_nid(sentence)
sen = []
if !ismissing(sentence)
for i in word_tokeniser(sentence)
try get_embeddings(i)
push!(sen,i)
catch e
end
end
end
return sen
end
# Evaluates the sentiment given a sentence as input
function evaluate_model(sen)
x_em = []
sen = remove_nid(sen)
for i in (sen)
if length(x_em) == 0
x_em = get_embeddings(i)
else
x_em = hcat(x_em,get_embeddings(i))
end
end
α = forwardprop(x_em,Q,Qb,K,Kb,V,Vb,W,b)[5]
# plot heatmap of α
heatmap(sen,sen,α,clims=(0,1),aspect_ratio=1,color=:deepsea,
title="Attention weights α",grid="off")
end
# main
# Random seed for reproducibility
rng = MersenneTwister(12);
# Initialize small random parameter values
Q = randn(rng, (51, 51))/100
Qb = zeros(51,1)
K = randn(rng, (51, 51))/100
Kb = zeros(51,1)
V = K
Vb = zeros(51,1)
W = randn(rng, (3, 51))/100
b = zeros(3,1)
# Sentiment dictionary that converts sentiment
# text into one-hot labels
sent_dict = Dict("positive"=>[0,0,1],"negative"=>[1,0,0],"neutral"=>[0,1,0])
#training
for epoch=1:1000
total_l = 0 #total loss
for idx in 1:nrow(tb)
x_em = []
l = 0 #current loss
sen = tb[idx,"cleaned_review"] #gets sentence
sen = remove_nid(sen) #remove words not in dictionary
if length(sen)!=0
for i in (sen)
if length(x_em) == 0
x_em = get_embeddings(i)
else
#Concatenate word embeddings along columns
x_em = hcat(x_em,get_embeddings(i))
end
end
#One hot vector sentiment
y = sent_dict[tb[idx,"sentiments"]]
#Update parameters
Q,Qb,K,Kb,V,Vb,W,b,l = train(x_em,y,Q,Qb,K,Kb,V,Vb,W,b)
end
total_l += l
end
println("Total loss:", total_l/nrow(tb))
end
# vizualize attention weights
evaluate_model("very sad as they both fail")
And that is attention from scratch!
References
- http://neuralnetworksanddeeplearning.com/chap1.html
- https://cs231n.github.io/optimization-2/
- https://web.stanford.edu/class/archive/cs/cs224n/cs224n.1234/slides/cs224n-2023-lecture08-transformers.pdf
- https://kcin96.github.io/notes/ml/2024/01/18/backpropagation.html
- https://kcin96.github.io/notes/ml/2023/12/29/why-mean-squared-loss-works-poorly-with-softmax.html
- Ashish Vaswani et al. 2017. Attention Is All You Need.
- Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. GloVe: Global Vectors for Word Representation.