Overview

This section explores backpropagating through the individual attention components (Values, Query, Keys, Softmax).

Neural network recap

[Figure: a 2 layer network with 3 input nodes and 4 output nodes (left), equivalent to the matrix form W x + b = z (right)]

The above left diagram shows a 2 layer neural network with 3 input nodes and 4 output nodes. This network takes a weighted sum of the 3 inputs and outputs 4 numbers, where the connections between the first and last layer nodes represent the weights. This network can be represented by the matrix multiplication $\boldsymbol{W}\boldsymbol{x}+\boldsymbol{b}=\boldsymbol{z}$ (1) as shown on the right. Written out in full, \(\begin{equation} \begin{bmatrix} w_{11} \ w_{12} \ w_{13} \\ w_{21} \ w_{22} \ w_{23} \\ w_{31} \ w_{32} \ w_{33} \\ w_{41} \ w_{42} \ w_{43} \end{bmatrix} \ \begin{bmatrix} x_1 \\ x_2 \\ x_3 \end{bmatrix}+ \begin{bmatrix} b_1 \\ b_2 \\ b_3 \\ b_4 \end{bmatrix}= \begin{bmatrix} z_1 \\ z_2 \\ z_3 \\ z_4 \end{bmatrix} \end{equation}\)

  • Forward propagate step:

\(\begin{aligned} &\boldsymbol{z} = \boldsymbol{W} \boldsymbol{x}+\boldsymbol{b} \\ \end{aligned}\)

  • Backpropagate step:

Let $\boldsymbol{y}$ be the 4x1 target vector, and using the L2 loss function, we can backpropagate using the following computations \(\begin{aligned} &L = \frac{1}{2}||\boldsymbol{z} - \boldsymbol{y}||^2_2 \\ &\frac{\partial{L}}{\partial{\boldsymbol{z}}}=\boldsymbol{z}-\boldsymbol{y} \\ &\frac{\partial{\boldsymbol{z}}}{\partial{\boldsymbol{W}}}=\boldsymbol{x}^T,\frac{\partial{\boldsymbol{z}}}{\partial{\boldsymbol{b}}}=\boldsymbol{1} \\ &\begin{bmatrix} \frac{\partial{z_1}}{\partial{w_{11}}} \ \frac{\partial{z_1}}{\partial{w_{12}}} \ \frac{\partial{z_1}}{\partial{w_{13}}} \\ \frac{\partial{z_2}}{\partial{w_{21}}} \ \frac{\partial{z_2}}{\partial{w_{22}}} \ \frac{\partial{z_2}}{\partial{w_{23}}} \\ \frac{\partial{z_3}}{\partial{w_{31}}} \ \frac{\partial{z_3}}{\partial{w_{32}}} \ \frac{\partial{z_3}}{\partial{w_{33}}} \\ \frac{\partial{z_4}}{\partial{w_{41}}} \ \frac{\partial{z_4}}{\partial{w_{42}}} \ \frac{\partial{z_4}}{\partial{w_{43}}} \\ \end{bmatrix} = \begin{bmatrix} x_1 \ x_2 \ x_3 \\ x_1 \ x_2 \ x_3 \\ x_1 \ x_2 \ x_3 \\ x_1 \ x_2 \ x_3 \\ \end{bmatrix}, \ \begin{bmatrix} \frac{\partial{z_1}}{\partial{b_1}} \\ \frac{\partial{z_2}}{\partial{b_2}} \\ \frac{\partial{z_3}}{\partial{b_3}} \\ \frac{\partial{z_4}}{\partial{b_4}} \\ \end{bmatrix}= \begin{bmatrix} 1 \\ 1 \\ 1 \\ 1 \\\end{bmatrix}\\ \end{aligned}\)

  • Update step:
\[\begin{aligned} &\boldsymbol{W}^{new} = \boldsymbol{W}-\eta \frac{\partial{L}}{\partial{\boldsymbol{W}}}, \ \boldsymbol{W}^{new} = \boldsymbol{W}-\eta (\boldsymbol{z}-\boldsymbol{y}) \boldsymbol{x}^T \\ &\boldsymbol{b}^{new} = \boldsymbol{b}-\eta \frac{\partial{L}}{\partial{\boldsymbol{b}}}, \ \boldsymbol{b}^{new} = \boldsymbol{b}-\eta (\boldsymbol{z}-\boldsymbol{y}) \end{aligned}\]
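As a quick sanity check, the following Julia snippet sketches the forward, backpropagate and update steps above on made-up numbers (the values of W, b, x, y and η are illustrative only):

# Minimal sketch of forward propagate, backpropagate and update
W = randn(4,3)              # 4x3 weights
b = zeros(4)                # 4x1 bias
x = [1.0; 2.0; 3.0]         # 3x1 input
y = [0.0; 1.0; -1.0; 2.0]   # 4x1 target
η = 0.01

for i=1:100
    global W, b
    z = W*x + b             # forward propagate
    ∂L_∂z = z - y           # gradient of the L2 loss
    W = W - η * ∂L_∂z * x'  # Wnew = W - η(z-y)x'
    b = b - η * ∂L_∂z       # bnew = b - η(z-y)
end
println(W*x + b)            # approaches y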

We can simplify expression (1) further by absorbing the bias vector $\boldsymbol{b}$ into the weight matrix $\boldsymbol{W}$ and appending a 1 to $\boldsymbol{x}$.

\[\begin{aligned} \boldsymbol{W}\boldsymbol{x}&=\boldsymbol{z}\\ \begin{bmatrix} w_{11} \ w_{12} \ w_{13} \ b_1 \\ w_{21} \ w_{22} \ w_{23} \ b_2 \\ w_{31} \ w_{32} \ w_{33} \ b_3 \\ w_{41} \ w_{42} \ w_{43} \ b_4 \end{bmatrix} \begin{bmatrix} x_{1} \\ x_{2} \\ x_{3} \\ 1 \end{bmatrix}&= \begin{bmatrix} z_{1} \\ z_{2} \\ z_{3} \\ z_{4} \end{bmatrix}\\ \end{aligned}\]
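A quick Julia check of this absorption trick (numbers are arbitrary):

W = [1.0 2 3; 4 5 6; 7 8 9; 1 0 2]   # 4x3 weights
b = [0.5; -1.0; 2.0; 0.0]            # 4x1 bias
x = [1.0; 2.0; 3.0]
W_aug = hcat(W, b)                   # absorb b as an extra column of W
x_aug = vcat(x, 1)                   # append a 1 to x
println(W_aug*x_aug ≈ W*x + b)       # true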

Multidimensional input

[Figure: the weight matrix (blue) and bias (green) applied to a 3x3 input matrix x (brown), producing the output z (red)]

In the previous section we passed a column vector input $\boldsymbol{x}$ through the network to get the column vector $\boldsymbol{z}$. Now let us extend that concept further by assuming that we have a matrix input instead of just a column vector. The figure above shows a 3x3 matrix input x (3 column vectors concatenated together, shown in brown) passed through the network weights (blue) and bias (green) to produce the output z (red).

  • Forward propagate step:

\(\begin{aligned} \boldsymbol{W}\boldsymbol{x}=& \ \boldsymbol{z}\\ \begin{bmatrix} w_{11} \ w_{12} \ b_1 \\ w_{21} \ w_{22} \ b_2 \end{bmatrix} \begin{bmatrix} x_{11} \ x_{12} \ x_{13} \\ x_{21} \ x_{22} \ x_{23} \\ 1 \ \ \ 1 \ \ \ 1 \end{bmatrix}=& \begin{bmatrix} z_{11} \ z_{12} \ z_{13} \\ z_{21} \ z_{22} \ z_{23} \end{bmatrix}\\ \begin{bmatrix} w_{11}x_{11}+w_{12}x_{21}+b_1 \ \ \ w_{11}x_{12}+w_{12}x_{22}+b_1 \ \ \ w_{11}x_{13}+w_{12}x_{23}+b_1 \\ w_{21}x_{11}+w_{22}x_{21}+b_2 \ \ \ w_{21}x_{12}+w_{22}x_{22}+b_2 \ \ \ w_{21}x_{13}+w_{22}x_{23}+b_2 \end{bmatrix}=& \begin{bmatrix} z_{11} \ z_{12} \ z_{13} \\ z_{21} \ z_{22} \ z_{23} \end{bmatrix} \end{aligned}\)
The matrix multiplication operation can be thought of as passing 3 column vector inputs separately to the network and then concatenating the individual outputs next to each other.
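This can be verified with a few lines of Julia (numbers are arbitrary, shapes match the 2x3 weight-plus-bias matrix above):

W = [0.1 0.2 0.5; 0.3 0.4 -0.2]              # 2x3, last column is the bias
x = [1.0 2.0 3.0; 4.0 5.0 6.0; 1 1 1]        # 3 column inputs, last row of 1s
z = W*x                                      # all 3 columns at once
z_cols = hcat(W*x[:,1], W*x[:,2], W*x[:,3])  # each column passed separately
println(z ≈ z_cols)                          # true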

  • Backpropagate step:

Let $\boldsymbol{y}$ be the 2x3 target matrix, and using the L2 loss function, we calculate the error of the 2x3 output $\boldsymbol{z}$. We know how to backpropagate with a column vector (previous section), so let us break the problem into 3 steps - backpropagating with $\boldsymbol{z_1}$, then $\boldsymbol{z_2}$, then $\boldsymbol{z_3}$ (the columns of $\boldsymbol{z}$).

\[\begin{aligned} &L = \frac{1}{2}||\boldsymbol{z} - \boldsymbol{y}||^2_2 \\ &\frac{\partial{L}}{\partial{\boldsymbol{z}}}=\boldsymbol{z}-\boldsymbol{y} \\ \end{aligned}\]

Step 1

[Figure: backpropagating the column 1 errors ∂L/∂z11, ∂L/∂z21 back to the weights and bias]
  • Step 1 - Backpropagate step:
\[\begin{aligned} &\frac{\partial{\boldsymbol{z_1}}}{\partial{\boldsymbol{W}}}=\boldsymbol{x_1}^T,\frac{\partial{\boldsymbol{z_1}}}{\partial{\boldsymbol{b}}}=\boldsymbol{1} \\ &\begin{bmatrix} \frac{\partial{z_{11}}}{\partial{w_{11}}} \ \frac{\partial{z_{11}}}{\partial{w_{12}}} \\ \frac{\partial{z_{21}}}{\partial{w_{21}}} \ \frac{\partial{z_{21}}}{\partial{w_{22}}} \\ \end{bmatrix} = \begin{bmatrix} x_{11} \ x_{21} \\ x_{11} \ x_{21} \\ \end{bmatrix}, \ \begin{bmatrix} \frac{\partial{z_{11}}}{\partial{b_1}} \\ \frac{\partial{z_{21}}}{\partial{b_2}} \\ \end{bmatrix}= \begin{bmatrix} 1 \\ 1 \\\end{bmatrix}\\ \end{aligned}\]
  • Step 1 - Update step:
\[\begin{aligned} &\boldsymbol{W}^{new} = \boldsymbol{W}-\eta \frac{\partial{L}}{\partial{\boldsymbol{W}}}, \ \boldsymbol{W}^{new} = \boldsymbol{W}-\eta (\boldsymbol{z_1}-\boldsymbol{y_1}) \boldsymbol{x_1}^T \\ &\boldsymbol{b}^{new} = \boldsymbol{b}-\eta \frac{\partial{L}}{\partial{\boldsymbol{b}}}, \ \boldsymbol{b}^{new} = \boldsymbol{b}-\eta (\boldsymbol{z_1}-\boldsymbol{y_1}) \end{aligned}\]

Step 2

[Figure: backpropagating the column 2 errors ∂L/∂z12, ∂L/∂z22 back to the weights and bias]
  • Step 2 - Backpropagate step:
\[\begin{aligned} &\frac{\partial{\boldsymbol{z_2}}}{\partial{\boldsymbol{W}}}=\boldsymbol{x_2}^T,\frac{\partial{\boldsymbol{z_2}}}{\partial{\boldsymbol{b}}}=\boldsymbol{1} \\ &\begin{bmatrix} \frac{\partial{z_{12}}}{\partial{w_{11}}} \ \frac{\partial{z_{12}}}{\partial{w_{12}}} \\ \frac{\partial{z_{22}}}{\partial{w_{21}}} \ \frac{\partial{z_{22}}}{\partial{w_{22}}} \\ \end{bmatrix} = \begin{bmatrix} x_{12} \ x_{22} \\ x_{12} \ x_{22} \\ \end{bmatrix}, \ \begin{bmatrix} \frac{\partial{z_{12}}}{\partial{b_1}} \\ \frac{\partial{z_{22}}}{\partial{b_2}} \\ \end{bmatrix}= \begin{bmatrix} 1 \\ 1 \\\end{bmatrix}\\ \end{aligned}\]
  • Step 2 - Update step:
\[\begin{aligned} &\boldsymbol{W}^{new} = \boldsymbol{W}-\eta \frac{\partial{L}}{\partial{\boldsymbol{W}}}, \ \boldsymbol{W}^{new} = \boldsymbol{W}-\eta (\boldsymbol{z_2}-\boldsymbol{y_2}) \boldsymbol{x_2}^T \\ &\boldsymbol{b}^{new} = \boldsymbol{b}-\eta \frac{\partial{L}}{\partial{\boldsymbol{b}}}, \ \boldsymbol{b}^{new} = \boldsymbol{b}-\eta (\boldsymbol{z_2}-\boldsymbol{y_2}) \end{aligned}\]

Step 3

[Figure: backpropagating the column 3 errors ∂L/∂z13, ∂L/∂z23 back to the weights and bias]
  • Step 3 - Backpropagate step:
\[\begin{aligned} &\frac{\partial{\boldsymbol{z_3}}}{\partial{\boldsymbol{W}}}=\boldsymbol{x_3}^T,\frac{\partial{\boldsymbol{z_3}}}{\partial{\boldsymbol{b}}}=\boldsymbol{1} \\ &\begin{bmatrix} \frac{\partial{z_{13}}}{\partial{w_{11}}} \ \frac{\partial{z_{13}}}{\partial{w_{12}}} \\ \frac{\partial{z_{23}}}{\partial{w_{21}}} \ \frac{\partial{z_{23}}}{\partial{w_{22}}} \\ \end{bmatrix} = \begin{bmatrix} x_{13} \ x_{23} \\ x_{13} \ x_{23} \\ \end{bmatrix}, \ \begin{bmatrix} \frac{\partial{z_{13}}}{\partial{b_1}} \\ \frac{\partial{z_{23}}}{\partial{b_2}} \\ \end{bmatrix}= \begin{bmatrix} 1 \\ 1 \\\end{bmatrix}\\ \end{aligned}\]
  • Step 3 - Update step:
\[\begin{aligned} &\boldsymbol{W}^{new} = \boldsymbol{W}-\eta \frac{\partial{L}}{\partial{\boldsymbol{W}}}, \ \boldsymbol{W}^{new} = \boldsymbol{W}-\eta (\boldsymbol{z_3}-\boldsymbol{y_3}) \boldsymbol{x_3}^T \\ &\boldsymbol{b}^{new} = \boldsymbol{b}-\eta \frac{\partial{L}}{\partial{\boldsymbol{b}}}, \ \boldsymbol{b}^{new} = \boldsymbol{b}-\eta (\boldsymbol{z_3}-\boldsymbol{y_3}) \end{aligned}\]

If we look at it closely, the 3 steps are just updating $\boldsymbol{W}$ and the bias parameters 3 times, so we can actually combine all 3 steps into a single matrix multiplication. \(\begin{aligned} \begin{bmatrix} \frac{\partial{L}}{\partial{z_{11}}} \ \frac{\partial{L}}{\partial{z_{12}}} \ \frac{\partial{L}}{\partial{z_{13}}}\\ \frac{\partial{L}}{\partial{z_{21}}} \ \frac{\partial{L}}{\partial{z_{22}}} \ \frac{\partial{L}}{\partial{z_{23}}} \end{bmatrix} \begin{bmatrix} x_{11} \ x_{21} \\ x_{12} \ x_{22} \\ x_{13} \ x_{23} \\ \end{bmatrix}= \begin{bmatrix} \frac{\partial{L}}{\partial{z_{11}}}x_{11}+\frac{\partial{L}}{\partial{z_{12}}}x_{12}+\frac{\partial{L}}{\partial{z_{13}}}x_{13} \ \ \ \frac{\partial{L}}{\partial{z_{11}}}x_{21}+\frac{\partial{L}}{\partial{z_{12}}}x_{22}+\frac{\partial{L}}{\partial{z_{13}}}x_{23} \\ \frac{\partial{L}}{\partial{z_{21}}}x_{11}+\frac{\partial{L}}{\partial{z_{22}}}x_{12}+\frac{\partial{L}}{\partial{z_{23}}}x_{13} \ \ \ \frac{\partial{L}}{\partial{z_{21}}}x_{21}+\frac{\partial{L}}{\partial{z_{22}}}x_{22}+\frac{\partial{L}}{\partial{z_{23}}}x_{23} \\ \end{bmatrix} \end{aligned}\)

  • Update step:

\(\begin{aligned} &\boldsymbol{W}^{new} = \boldsymbol{W}-\eta \frac{\partial{L}}{\partial{\boldsymbol{W}}}, \ \boldsymbol{W}^{new} = \boldsymbol{W}-\eta (\boldsymbol{z}-\boldsymbol{y}) \boldsymbol{x}^T \\ &\boldsymbol{b}^{new} = \boldsymbol{b}-\eta \frac{\partial{L}}{\partial{\boldsymbol{b}}}, \ \boldsymbol{b}^{new} = \boldsymbol{b}-\eta \ sumcols(\boldsymbol{z}-\boldsymbol{y}) \end{aligned}\) where sumcols means summing the columns.
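The combined gradient and the sumcols bias update can be sketched in a few lines of Julia (same 2-output shapes as above, with made-up numbers):

W = [0.1 0.2; 0.3 0.4]             # 2x2 weights (bias kept separate here)
b = [0.0; 0.0]
x = [1.0 2.0 3.0; 4.0 5.0 6.0]     # 3 column inputs (2x3)
y = [1.0 0.0 -1.0; 0.5 0.5 0.5]    # 2x3 target
η = 0.01

z = W*x .+ b
∂L_∂z = z - y
∂L_∂W = ∂L_∂z * x'                 # one matrix product replaces the 3 per-column steps
∂L_∂b = sum(∂L_∂z, dims=2)         # sumcols: the bias is shared across the columns
W_new = W - η * ∂L_∂W
b_new = b - η * vec(∂L_∂b)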

Values, V

[Figure: z computed as the weighted sum α1·v1 + α2·v2]
  • Forward propagate step:
\[\boldsymbol{v_1} =\begin{bmatrix} v_{11} \\ v_{21} \\ v_{31} \end{bmatrix}, \ \boldsymbol{v_2} = \begin{bmatrix} v_{12} \\ v_{22} \\ v_{32} \end{bmatrix}, \ \boldsymbol{\alpha} = \begin{bmatrix} \alpha_{1} \ \alpha_{2} \ \end{bmatrix}\\\]

Concatenating $\boldsymbol{v_1}$ and $\boldsymbol{v_2}$ to $\boldsymbol{v}$ \(\begin{aligned} \boldsymbol{v} = \begin{bmatrix} v_{11} \ v_{12} \\ v_{21} \ v_{22} \\ v_{31} \ v_{32} \end{bmatrix} \end{aligned}\)

$\boldsymbol{z}$ can be computed vectorially as \(\begin{aligned} &\boldsymbol{z} = \begin{bmatrix} v_{11} \ v_{12} \\ v_{21} \ v_{22} \\ v_{31} \ v_{32} \end{bmatrix} \begin{bmatrix} \alpha_{1} \\ \alpha_{2} \end{bmatrix}\\ &\boldsymbol{z} = \boldsymbol{v} \boldsymbol{\alpha}^T \end{aligned}\)

  • Backpropagate step:

Let $\boldsymbol{y}$ be the 3x1 target vector.

\[\begin{aligned} &L = \frac{1}{2}||\boldsymbol{z} - \boldsymbol{y}||^2_2 \\ &\frac{\partial{L}}{\partial{\boldsymbol{z}}}=\boldsymbol{z}-\boldsymbol{y} \\ &\frac{\partial{L}}{\partial{\boldsymbol{z}}}=\begin{bmatrix} z_1 - y_1 \\ z_2 - y_2 \\ z_3 - y_3 \\ \end{bmatrix}\\ &\frac{\partial{\boldsymbol{z}}}{\partial{\boldsymbol{v}}}=\boldsymbol{\alpha}\\ &\begin{bmatrix} \frac{\partial{z_1}}{\partial{v_{11}}} \ \frac{\partial{z_1}}{\partial{v_{12}}} \\ \frac{\partial{z_2}}{\partial{v_{21}}} \ \frac{\partial{z_2}}{\partial{v_{22}}} \\ \frac{\partial{z_3}}{\partial{v_{31}}} \ \frac{\partial{z_3}}{\partial{v_{32}}} \\ \end{bmatrix} = \begin{bmatrix} \alpha_{1} \ \alpha_{2} \\ \alpha_{1} \ \alpha_{2} \\ \alpha_{1} \ \alpha_{2} \\ \end{bmatrix}, \\ &\frac{\partial{L}}{\partial{\boldsymbol{v}}}=\frac{\partial{L}}{\partial{\boldsymbol{z}}}\frac{\partial{\boldsymbol{z}}}{\partial{\boldsymbol{v}}}\\ \end{aligned}\]
[Figure: the error ∂L/∂z = z−y backpropagated to v1 and v2, scaled by α1 and α2, giving (z−y)α]

Intuitively, from the figure above, we can think of the errors being backpropagated to $v_1$ and $v_2$ scaled by the factors $\alpha_1$ and $\alpha_2$ respectively.

  • Update step: \(\begin{aligned} &\boldsymbol{v}^{new}=\boldsymbol{v}-\eta\frac{\partial{L}}{\partial{\boldsymbol{v}}}\\ &\boldsymbol{v}^{new}=\boldsymbol{v}-\eta(\boldsymbol{z}-\boldsymbol{y})\boldsymbol{\alpha}\\ \end{aligned}\)
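A minimal Julia sketch of this update, with illustrative values for v, α and y:

v = [1.0 2.0; 3.0 4.0; 5.0 6.0]   # 3x2, columns are v1 and v2
α = [0.7 0.3]                     # 1x2 attention weights (held fixed here)
y = [1.0; 0.0; 2.0]               # 3x1 target
η = 0.05

for i=1:500
    global v
    z = v*α'                      # forward: z = α1.v1 + α2.v2
    v = v - η * (z - y) * α       # vnew = v - η(z-y)α
end
println(v*α')                     # ≈ y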

Next, let us start stacking layers. For example, instead of $v_1$ & $v_2$ being the trainable parameters, they could be the output of a 2 layer feedforward network as shown below.

[Figure: v1 and v2 produced by a feedforward layer v = Vx + b, then combined into z = vα']

In this case we have the matrix $\boldsymbol{V}$ as the trainable parameters and $\boldsymbol{v}=\boldsymbol{V}\boldsymbol{x}+\boldsymbol{b}$

  • Backpropagate step:
\[\begin{aligned} &\frac{\partial{\boldsymbol{v}}}{\partial{\boldsymbol{V}}}=\boldsymbol{x}^T, &\frac{\partial{\boldsymbol{v}}}{\partial{\boldsymbol{b}}}=\boldsymbol{1}\\ &\frac{\partial{L}}{\partial{\boldsymbol{V}}}=\frac{\partial{L}}{\partial{\boldsymbol{z}}}\frac{\partial{\boldsymbol{z}}}{\partial{\boldsymbol{v}}}\frac{\partial{\boldsymbol{v}}}{\partial{\boldsymbol{V}}}\\ &\frac{\partial{L}}{\partial{\boldsymbol{V}}}= \begin{bmatrix} \frac{\partial{L}}{\partial{v_{11}}} \ \frac{\partial{L}}{\partial{v_{12}}} \\ \frac{\partial{L}}{\partial{v_{21}}} \ \frac{\partial{L}}{\partial{v_{22}}} \\ \frac{\partial{L}}{\partial{v_{31}}} \ \frac{\partial{L}}{\partial{v_{32}}} \\ \end{bmatrix} * \begin{bmatrix} x_{11} \ x_{21} \ x_{31}\\ x_{12} \ x_{22} \ x_{32}\\ \end{bmatrix} \\ &= \begin{bmatrix} \frac{\partial{L}}{\partial{v_{11}}}x_{11}+\frac{\partial{L}}{\partial{v_{12}}}x_{12} \ \ \frac{\partial{L}}{\partial{v_{11}}}x_{21}+\frac{\partial{L}}{\partial{v_{12}}}x_{22} \ \ \frac{\partial{L}}{\partial{v_{11}}}x_{31}+\frac{\partial{L}}{\partial{v_{12}}}x_{32} \\ \frac{\partial{L}}{\partial{v_{21}}}x_{11}+\frac{\partial{L}}{\partial{v_{22}}}x_{12} \ \ \frac{\partial{L}}{\partial{v_{21}}}x_{21}+\frac{\partial{L}}{\partial{v_{22}}}x_{22} \ \ \frac{\partial{L}}{\partial{v_{21}}}x_{31}+\frac{\partial{L}}{\partial{v_{22}}}x_{32} \\ \frac{\partial{L}}{\partial{v_{31}}}x_{11}+\frac{\partial{L}}{\partial{v_{32}}}x_{12} \ \ \frac{\partial{L}}{\partial{v_{31}}}x_{21}+\frac{\partial{L}}{\partial{v_{32}}}x_{22} \ \ \frac{\partial{L}}{\partial{v_{31}}}x_{31}+\frac{\partial{L}}{\partial{v_{32}}}x_{32} \\ \end{bmatrix} \end{aligned}\]
  • Update step:
\[\begin{aligned} &\boldsymbol{V}^{new}=\boldsymbol{V}-\eta\frac{\partial{L}}{\partial{\boldsymbol{V}}}\\ &\boldsymbol{V}^{new}=\boldsymbol{V}-\eta(\boldsymbol{z}-\boldsymbol{y})\boldsymbol{\alpha}\boldsymbol{x}^T\\ \end{aligned}\]

Notice that the product of the $(\boldsymbol{z}-\boldsymbol{y})\boldsymbol{\alpha}$ and $\boldsymbol{x}^T$ blocks has an implied sum of the change in $\boldsymbol{V}$ from both $v_1$ and $v_2$. We can think of it as similar to combining 2 independent sample ‘iteration’ updates (pass $x_1$ in iteration 1 and then update $\boldsymbol{V}$; pass $x_2$ in iteration 2 and then update $\boldsymbol{V}$).

\[\begin{aligned} &\boldsymbol{b}^{new}=\boldsymbol{b}-\eta\frac{\partial{L}}{\partial{\boldsymbol{b}}}\\ &\boldsymbol{b}^{new}=\boldsymbol{b}-\eta \ sumcols((\boldsymbol{z}-\boldsymbol{y})\boldsymbol{\alpha})\\ \end{aligned}\]

To update the bias parameter, we need to sum the 2 columns of $\frac{\partial{L}}{\partial{\boldsymbol{v}}}$, because the bias is shared by both $x_1$ and $x_2$. Think of it as combining 2 sample ‘iteration’ updates (pass $x_1$ in iteration 1 and then update $\boldsymbol{V}$ and $b$; pass $x_2$ in iteration 2 and then update $\boldsymbol{V}$ and $b$).

The following Julia code illustrates an example.

#L2 norm
function L2norm(x)
    sqrt(sum(x.^2)) 
end

#Squared Loss
function Squared_Loss(z,x)
    return 0.5*(L2norm(z-x))^2
end

#Computes output z
function FeedForward(W,b,x,y,α)
    v = W*x.+b
    z = v * α'
    return z
end

# Forward propagate
function forwardprop(W,b,x,y,α)
    z = FeedForward(W,b,x,y,α)
    return backprop(W,b,x,z,y,α)
end

# Backpropagate
function backprop(W,b,x,z,y,α,η=0.002)
    println("Loss:",Squared_Loss(z,y))
    ∂L_∂z = z-y
    ∂z_∂v = α
    ∂v_∂w = x'

    ∂L_∂w = ∂L_∂z*∂z_∂v*∂v_∂w 
    ∂L_∂b = ∂L_∂z*∂z_∂v

    #init weights
    W_new = W
    b_new = b

    #update step
    W_new = W .- η * ∂L_∂w
    b_new = b .- η * sum(∂L_∂b,dims=2)

    return W_new, b_new
end

V = [-1 6 7;2 2 2;5 5 5]
V_bias = [0;0;0]
α = [5 -2]
x = [[1 3]; [4 .3]; [2 2]]
y = [-1; 20; 5;]
for i=1:10
    global V, V_bias
    V, V_bias = forwardprop(V,V_bias,x,y,α)
end
println("V:",V)
println("V_bias:",V_bias)
println("z:",FeedForward(V,V_bias,x,y,α))

Query, Q

[Figure: q1 = Qx1 dotted with the keys k1 and k2 to produce e1 = q1·k1 and e2 = q1·k2]
  • Forward propagate step:
\[\begin{aligned} &\boldsymbol{Q}\boldsymbol{x_1}=\boldsymbol{q_1}\\ &e_1=\boldsymbol{q_1}^T\boldsymbol{k_1}\\ &e_2=\boldsymbol{q_1}^T\boldsymbol{k_2}\\ &\boldsymbol{e}=\boldsymbol{q_1}^T\boldsymbol{k}\\ &where \ \boldsymbol{Q}=\begin{bmatrix} w_{11} \ w_{12} \ b_1 \\ w_{21} \ w_{22} \ b_2 \end{bmatrix}, \boldsymbol{x_1}=\begin{bmatrix} x_{11} \\ x_{21} \\ 1 \end{bmatrix}, \boldsymbol{q_1}=\begin{bmatrix} q_{11} \\ q_{21} \end{bmatrix}, \boldsymbol{k_1}=\begin{bmatrix} k_{11} \\ k_{21} \end{bmatrix}, \boldsymbol{k_2}=\begin{bmatrix} k_{12} \\ k_{22} \end{bmatrix}, \boldsymbol{k}=\begin{bmatrix} k_{11} \ k_{12} \\ k_{21} \ k_{22} \end{bmatrix}, \boldsymbol{e}=\begin{bmatrix} e_1 \ e_2 \end{bmatrix} \end{aligned}\]

Note that we have absorbed the bias terms $b_1,b_2$ into $Q$ and appended ‘1’ to the vector $\boldsymbol{x_1}$ for compactness. The above steps can be written concisely in vector form, \(\begin{aligned} &\boldsymbol{e}=(\boldsymbol{Qx_1})^T\boldsymbol{k}\\ &\boldsymbol{e}=\boldsymbol{q_1}^T\boldsymbol{k} \end{aligned}\)

  • Backpropagate step:

Let $\boldsymbol{y}$ be the 2x1 target vector.

\[\begin{aligned} &L = \frac{1}{2}||\boldsymbol{e} - \boldsymbol{y}^T||^2_2 \\ &\frac{\partial{L}}{\partial{\boldsymbol{e}}}=\boldsymbol{e} -\boldsymbol{y}^T \\ &\frac{\partial{L}}{\partial{\boldsymbol{e}}}=\begin{bmatrix} e_1 - y_1 \ e_2 - y_2 \\ \end{bmatrix}\\ &\frac{\partial{\boldsymbol{e}}}{\partial{\boldsymbol{q_1}}}=\boldsymbol{k}^T\\ &\begin{bmatrix} \frac{\partial{e_1}}{\partial{q_{11}}} \ \frac{\partial{e_1}}{\partial{q_{21}}} \\ \frac{\partial{e_2}}{\partial{q_{11}}} \ \frac{\partial{e_2}}{\partial{q_{21}}} \\ \end{bmatrix} = \begin{bmatrix} k_{11} \ k_{21} \\ k_{12} \ k_{22} \\ \end{bmatrix}, \\ &\frac{\partial{\boldsymbol{q_1}}}{\partial{\boldsymbol{Q}}}=\boldsymbol{x_1}^T,\ &\frac{\partial{\boldsymbol{q_1}}}{\partial{\boldsymbol{b}}}=\boldsymbol{1}\\ &\frac{\partial{L}}{\partial{\boldsymbol{Q}}}=\frac{\partial{L}}{\partial{\boldsymbol{e}}}\frac{\partial{\boldsymbol{e}}}{\partial{\boldsymbol{q_1}}}\frac{\partial{\boldsymbol{q_1}}}{\partial{\boldsymbol{Q}}}\\ &\frac{\partial{L}}{\partial{\boldsymbol{q_1}}}=\frac{\partial{L}}{\partial{\boldsymbol{e}}}\frac{\partial{\boldsymbol{e}}}{\partial{\boldsymbol{q_1}}}=\begin{bmatrix} e_1 - y_1 \ e_2 - y_2 \\ \end{bmatrix}\begin{bmatrix} k_{11} \ k_{21} \\ k_{12} \ k_{22} \\ \end{bmatrix}\\ \end{aligned}\]
[Figure: the error signals ∂L/∂e1 and ∂L/∂e2 flowing back through k to q1 and on to Q]
  • Update step:
\[\begin{aligned} &\boldsymbol{Q}^{new}=\boldsymbol{Q}-\eta\frac{\partial{L}}{\partial{\boldsymbol{Q}}}\\ &\boldsymbol{Q}^{new}=\boldsymbol{Q}-\eta((\boldsymbol{e}-\boldsymbol{y}^T)\boldsymbol{k}^T)^T\boldsymbol{x_1}^T\\ \end{aligned}\] \[\begin{aligned} &\boldsymbol{b}^{new}=\boldsymbol{b}-\eta\frac{\partial{L}}{\partial{\boldsymbol{b}}}\\ &\boldsymbol{b}^{new}=\boldsymbol{b}-\eta \ ((\boldsymbol{e}-\boldsymbol{y}^T)\boldsymbol{k}^T)^T\\ \end{aligned}\]

The following Julia code illustrates an example with 3x1 keys $\boldsymbol{k_1},\boldsymbol{k_2},\boldsymbol{k_3}$, a 3x4 $Q$ matrix (last column holding the bias) and a 4x1 input vector $\boldsymbol{x}$ (last entry set to 1).

# Squared Loss
function Squared_Loss(z,x)
    return 0.5*(L2norm(z-x))^2
end

# L2 norm
function L2norm(x)
    sqrt(sum(x.^2))
end

#Feed forward
function FeedForward(Q,k,x,y)
    q = Q*x
    e = q'*k
    return q,e
end

# Forward propagate
function forwardprop(Q,k,x,y)
    q, e = FeedForward(Q,k,x,y)
    return backprop(Q,k,x,q,e,y)
end

# Backpropagate
function backprop(Q,k,x,q,e,y,η=.02)
    println(Squared_Loss(e',y))

    ∂L_∂e = (e-y')
    ∂e_∂q = k'
    ∂q_∂Q = x[1:end-1]'

    ∂L_∂q = ∂L_∂e*∂e_∂q
    ∂L_∂Q = ∂L_∂q'*∂q_∂Q   #outer product of ∂L/∂q (as a column) with x', matching the update equation above
    ∂L_∂Qb = ∂L_∂q

    #init Q weights
    Q_new = Q

    #update step
    Q_new[:,1:end-1] = Q_new[:,1:end-1] .- η * ∂L_∂Q
    Q_new[:,end:end] = Q_new[:,end:end] .- η * ∂L_∂Qb' 

    return Q_new
end


Q = [1. 5 .2 0; 1 2 .4 0; 4 5 1 0;] #Last column is bias terms initialized to 0
k = [1 -3 6; 2 0 1; 4 5 1;]
x = [0.5; 0.5; .3; 1] #Last row is bias term set to 1
y = [-3.14; -6.3; 2.21;]

for i=1:100
    global Q
    Q = forwardprop(Q,k,x,y)
end
println("Q:",Q)
println("e:",FeedForward(Q,k,x,y)[2])

Keys, K

[Figure: keys k1 and k2 produced by the K weights from x1 and x2, dotted with q1 to produce e1 and e2]
  • Forward propagate step:
\[\begin{aligned} &\boldsymbol{K}\boldsymbol{x}=\boldsymbol{k}\\ &e_1=\boldsymbol{q_1}^T\boldsymbol{k_1}\\ &e_2=\boldsymbol{q_1}^T\boldsymbol{k_2}\\ &\boldsymbol{e}=\boldsymbol{q_1}^T\boldsymbol{k}\\ &where \ \boldsymbol{K}=\begin{bmatrix} w_{11} \ w_{12} \ b_1 \\ w_{21} \ w_{22} \ b_2 \end{bmatrix}, \boldsymbol{x}=\begin{bmatrix} \boldsymbol{x_1} \ \boldsymbol{x_2} \end{bmatrix}= \begin{bmatrix} x_{11} \ x_{12} \\ x_{21} \ x_{22} \\ 1 \ \ 1 \\ \end{bmatrix}, \boldsymbol{q_1}=\begin{bmatrix} q_{11} \\ q_{21} \end{bmatrix}, \boldsymbol{k_1}=\begin{bmatrix} k_{11} \\ k_{21} \end{bmatrix}, \boldsymbol{k_2}=\begin{bmatrix} k_{12} \\ k_{22} \end{bmatrix}, \boldsymbol{k}=\begin{bmatrix} k_{11} \ k_{12} \\ k_{21} \ k_{22} \end{bmatrix}, \boldsymbol{e}=\begin{bmatrix} e_1 \ e_2 \end{bmatrix} \end{aligned}\]

Note that as before we have absorbed the bias terms $b_1,b_2$ into $K$ and appended ‘1’s to the vectors $\boldsymbol{x_1}$, $\boldsymbol{x_2}$ for compactness. The above steps can be written concisely in vector form, \(\begin{aligned} &\boldsymbol{e}=\boldsymbol{q_1}^T(\boldsymbol{Kx})\\ &\boldsymbol{e}=\boldsymbol{q_1}^T\boldsymbol{k} \end{aligned}\)

  • Backpropagate step:

Let $\boldsymbol{y}$ be the 2x1 target vector.

\[\begin{aligned} &L = \frac{1}{2}||\boldsymbol{e} - \boldsymbol{y}^T||^2_2 \\ &\frac{\partial{L}}{\partial{\boldsymbol{e}}}=\boldsymbol{e} -\boldsymbol{y}^T \\ &\frac{\partial{L}}{\partial{\boldsymbol{e}}}=\begin{bmatrix} e_1 - y_1 \ e_2 - y_2 \\ \end{bmatrix}\\ &\frac{\partial{\boldsymbol{e}}}{\partial{\boldsymbol{k}}}=\boldsymbol{q_1}^T\\ &\begin{bmatrix} \frac{\partial{e_1}}{\partial{k_{11}}} \ \frac{\partial{e_1}}{\partial{k_{21}}} \\ \frac{\partial{e_2}}{\partial{k_{12}}} \ \frac{\partial{e_2}}{\partial{k_{22}}} \\ \end{bmatrix} = \begin{bmatrix} q_{11} \ q_{21} \\ q_{11} \ q_{21} \\ \end{bmatrix}, \\ &\frac{\partial{\boldsymbol{k}}}{\partial{\boldsymbol{K}}}=\boldsymbol{x}^T,\ &\frac{\partial{\boldsymbol{k}}}{\partial{\boldsymbol{b}}}=\boldsymbol{1}\\ &\frac{\partial{L}}{\partial{\boldsymbol{K}}}=\frac{\partial{L}}{\partial{\boldsymbol{e}}}\frac{\partial{\boldsymbol{e}}}{\partial{\boldsymbol{k}}}\frac{\partial{\boldsymbol{k}}}{\partial{\boldsymbol{K}}}\\ &\frac{\partial{L}}{\partial{\boldsymbol{k}}}=\frac{\partial{L}}{\partial{\boldsymbol{e}}}\frac{\partial{\boldsymbol{e}}}{\partial{\boldsymbol{k}}}=\begin{bmatrix} e_1 - y_1 \\ e_2 - y_2 \\ \end{bmatrix}\begin{bmatrix} q_{11} \ q_{21} \\ \end{bmatrix}\\ &\frac{\partial{L}}{\partial{\boldsymbol{k}}}= \begin{bmatrix} \frac{\partial{L}}{\partial{\boldsymbol{k_1}}}\\ \frac{\partial{L}}{\partial{\boldsymbol{k_2}}} \\ \end{bmatrix}= \begin{bmatrix} q_{11} (e_1 - y_1) \ \ q_{21} (e_1 - y_1) \\ q_{11}(e_2 - y_2) \ \ q_{21} (e_2 - y_2) \\ \end{bmatrix} \end{aligned}\]
[Figure: the error signals ∂L/∂e1 and ∂L/∂e2 flowing back through q1 to k1 and k2 and on to K]
  • Update step:
\[\begin{aligned} &\boldsymbol{K}^{new}=\boldsymbol{K}-\eta\frac{\partial{L}}{\partial{\boldsymbol{K}}}\\ &\boldsymbol{K}^{new}=\boldsymbol{K}-\eta((\boldsymbol{e}-\boldsymbol{y}^T)^T\boldsymbol{q_1}^T)^T\boldsymbol{x^T}\\ \end{aligned}\] \[\begin{aligned} &\boldsymbol{b}^{new}=\boldsymbol{b}-\eta\frac{\partial{L}}{\partial{\boldsymbol{b}}}\\ &\boldsymbol{b}^{new}=\boldsymbol{b}-\eta \ sumcols(((\boldsymbol{e}-\boldsymbol{y}^T)^T\boldsymbol{q_1}^T)^T)\\ \end{aligned}\]

The following Julia code illustrates an example with 3x1 keys $\boldsymbol{k_1},\boldsymbol{k_2},\boldsymbol{k_3}$, a 3x4 $K$ matrix (last column holding the bias) and a 4x3 input matrix $\boldsymbol{x}$ (last row set to 1s).

# Squared Loss
function Squared_Loss(z,x)
    return 0.5*(L2norm(z-x))^2
end

# L2 norm
function L2norm(x)
    sqrt(sum(x.^2))
end

#Feed forward
function FeedForward(K,q,x,y)
    k = K*x
    e = q'*k
    return k,e
end

# Forward propagate
function forwardprop(K,q,x,y)
    k, e = FeedForward(K,q,x,y)
    return backprop(K,k,x,q,e,y)
end

# Backpropagate
function backprop(K,k,x,q,e,y,η=.01)
    println(Squared_Loss(e',y))
  
    ∂L_∂e = (e-y')'
    ∂e_∂k = q'
    ∂k_∂K = x[1:end-1,:]'

    ∂L_∂k = ∂L_∂e*∂e_∂k
    ∂L_∂K = ∂L_∂k'*∂k_∂K
    ∂L_∂Kb = ∂L_∂k'

    #init K weights
    K_new = K

    #update step
    K_new[:,1:end-1] = K_new[:,1:end-1] .- η * ∂L_∂K
    K_new[:,end:end] = K_new[:,end:end] .- η * sum(∂L_∂Kb,dims=2) 

    return K_new
end


K = [1 -3 6. 0; 2 0 1 0; 4 5 1 0;] #Last column is bias terms initialized to 0
q = [1.; 5; .2;]
x = [0.5 .2 .4; 0.5 .2 .6; .3 .25 .7; 1 1 1;] #Last row are bias terms set to 1
y = [-3.14; -6.3; 2.21;]

for i=1:300
    global K
    K = forwardprop(K,q,x,y)
end
println("K:",K)
println("e:",FeedForward(K,q,x,y)[2])

Softmax

[Figure: left, the softmax as a black box mapping e1, e2 to α1, α2; right, its circuit form with α1 = exp(e1)/(exp(e1)+exp(e2)) and α2 = exp(e2)/(exp(e1)+exp(e2))]

The left diagram above shows an abstraction of the softmax function (a black box). If we pass 2 values $e_1,e_2$ through the ‘black box’ softmax function, we get 2 values $\alpha_1, \alpha_2$ out. The right diagram shows the circuit-like representation (the internals) of the softmax function.

  • Forward step \(\begin{aligned} softmax \ function, \ \sigma{(x_i)}=\frac{e^{x_i}}{\sum_{j=1}^2 {e^{x_j}}} \\ for \ i=1,2 \end{aligned}\)

  • Backpropagate step

\(\begin{aligned} \frac{\partial{L}}{\partial{\boldsymbol{\alpha}}}&= \begin{bmatrix} \frac{\partial{L}}{\partial{\alpha_1}} \ \frac{\partial{L}}{\partial{\alpha_2}} \end{bmatrix}\\ \frac{\partial{L}}{\partial{e_i}}&= \frac{\partial{L}}{\partial{\boldsymbol{\alpha}}}\frac{\partial{\boldsymbol{\alpha}}}{\partial{e_i}}\\ \frac{\partial{\boldsymbol{\alpha}}}{\partial{e_i}}&= \begin{bmatrix} \frac{\partial{\alpha_1}}{\partial{e_i}} \\ \frac{\partial{\alpha_2}}{\partial{e_i}} \end{bmatrix}\\ \frac{\partial{L}}{\partial{e_1}}&= \frac{\partial{L}}{\partial{\alpha_1}}\frac{\partial{\alpha_1}}{\partial{e_1}}+\frac{\partial{L}}{\partial{\alpha_2}}\frac{\partial{\alpha_2}}{\partial{e_1}}\\ &=\frac{\partial{L}}{\partial{\alpha_1}}(\alpha_1)(1-\alpha_1)+\frac{\partial{L}}{\partial{\alpha_2}}(-\alpha_1\alpha_2)\\ \frac{\partial{L}}{\partial{e_2}}&=\frac{\partial{L}}{\partial{\alpha_1}}\frac{\partial{\alpha_1}}{\partial{e_2}}+\frac{\partial{L}}{\partial{\alpha_2}}\frac{\partial{\alpha_2}}{\partial{e_2}}\\ &=\frac{\partial{L}}{\partial{\alpha_1}}(-\alpha_1\alpha_2)+\frac{\partial{L}}{\partial{\alpha_2}}(\alpha_2)(1-\alpha_2)\\ \end{aligned}\)

[Figure: the gradient signals ∂L/∂α1 (red) and ∂L/∂α2 (blue) flowing back through the division, sum and exp nodes to e1 and e2]

To see this visually as signals flowing back, let $a=e^{e_1}, c=e^{e_2}, b=e^{e_1}+e^{e_2}, \alpha_1=\frac{a}{b},\alpha_2=\frac{c}{b}$. Next, let us rewrite the division operator as a product of $a$ and $1/b$.
Referring to the above figure, we have red and blue arrows denoting signals from $\alpha_1$ and $\alpha_2$ respectively. Let us first zoom in on the $\alpha_1$ output. On the backward pass, the signal forks into 2 branches - a top and a mid branch. On the top branch, \(\begin{aligned} \frac{\partial{\alpha_1}}{\partial{a}}&=\frac{1}{b}\\ \frac{\partial{a}}{\partial{e_1}}&=e^{e_1}\\ (\frac{\partial{\alpha_1}}{\partial{e_1}})_{top}&=\frac{e^{e_1}}{b}\\ \end{aligned}\) On the mid branch, \(\begin{aligned} \frac{\partial{\alpha_1}}{\partial{b}}&=-\frac{a}{b^2}\\ \frac{\partial{b}}{\partial{e_1}}&=e^{e_1}\\ (\frac{\partial{\alpha_1}}{\partial{e_1}})_{mid}&=-\frac{e^{e_1}a}{b^2}\\ \end{aligned}\) As both paths converge at the ‘exp’ node, we add them together, giving \(\begin{aligned} \frac{\partial{\alpha_1}}{\partial{e_1}}&=\frac{e^{e_1}}{b}-\frac{e^{e_1}a}{b^2}\\ &=\frac{e^{e_1}}{b}(1-\frac{a}{b})\\ &=\alpha_1(1-\alpha_1) \end{aligned}\)

Now zooming in to the $\alpha_2$ output and working backwards, we see that we have a path to $e_1$ via the $b$ mid branch. On the mid branch, \(\begin{aligned} \frac{\partial{\alpha_2}}{\partial{b}}&=-\frac{c}{b^2}\\ \frac{\partial{b}}{\partial{e_1}}&=e^{e_1}\\ (\frac{\partial{\alpha_2}}{\partial{e_1}})_{mid}&=-\frac{e^{e_1}c}{b^2}\\ &=-\alpha_1\alpha_2 \end{aligned}\)

The same steps can be applied to obtain the gradients for $e_2$.
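The following Julia snippet sketches the softmax forward and backward pass for 2 inputs and checks it against the expressions derived above (the upstream gradient ∂L/∂α is just made-up numbers):

softmax(e) = exp.(e) ./ sum(exp.(e))

# ∂L/∂e_j = Σ_i ∂L/∂α_i . ∂α_i/∂e_j, using ∂α_i/∂e_j = α_i(δ_ij - α_j)
function softmax_backward(α, ∂L_∂α)
    return α .* (∂L_∂α .- sum(∂L_∂α .* α))
end

e = [1.2; -0.3]
α = softmax(e)
∂L_∂α = [0.5; -1.0]                  # pretend upstream gradient
∂L_∂e = softmax_backward(α, ∂L_∂α)

# matches the 2-input formulas above
println(∂L_∂e[1] ≈ ∂L_∂α[1]*α[1]*(1-α[1]) + ∂L_∂α[2]*(-α[1]*α[2]))
println(∂L_∂e[2] ≈ ∂L_∂α[1]*(-α[1]*α[2]) + ∂L_∂α[2]*α[2]*(1-α[2]))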

  3. https://web.stanford.edu/class/archive/cs/cs224n/cs224n.1234/slides/cs224n-2023-lecture08-transformers.pdf