Backpropagation with Attention Components
Overview
This section explores backpropagating through the attention components (Query, Keys, Values, Softmax) individually.
Neural network recap
The above left diagram shows a 2-layer neural network with 3 input nodes and 4 output nodes. The network takes weighted sums of the 3 inputs and outputs 4 numbers; the connections between the input-layer and output-layer nodes represent the weights. This network can be represented by the matrix equation $\boldsymbol{W}\boldsymbol{x}+\boldsymbol{b}=\boldsymbol{z}$ (1) as shown on the right. Written in full vector form, \(\begin{equation} \begin{bmatrix} w_{11} \ w_{12} \ w_{13} \\ w_{21} \ w_{22} \ w_{23} \\ w_{31} \ w_{32} \ w_{33} \\ w_{41} \ w_{42} \ w_{43} \end{bmatrix} \ \begin{bmatrix} x_1 \\ x_2 \\ x_3 \end{bmatrix}+ \begin{bmatrix} b_1 \\ b_2 \\ b_3 \\ b_4 \end{bmatrix}= \begin{bmatrix} z_1 \\ z_2 \\ z_3 \\ z_4 \end{bmatrix} \end{equation}\)
- Forward propagate step:
\(\begin{aligned} &\boldsymbol{z} = \boldsymbol{W} \boldsymbol{x}+\boldsymbol{b} \\ \end{aligned}\) \
- Backpropagate step:
Let $\boldsymbol{y}$ be the 4x1 target vector, and using the L2 loss function, we can backpropagate using the following computations \(\begin{aligned} &L = \frac{1}{2}||\boldsymbol{z} - \boldsymbol{y}||^2_2 \\ &\frac{\partial{L}}{\partial{\boldsymbol{z}}}=\boldsymbol{z}-\boldsymbol{y} \\ &\frac{\partial{\boldsymbol{z}}}{\partial{\boldsymbol{W}}}=\boldsymbol{x}^T,\frac{\partial{\boldsymbol{z}}}{\partial{\boldsymbol{b}}}=\boldsymbol{1} \\ &\begin{bmatrix} \frac{\partial{z_1}}{\partial{w_{11}}} \ \frac{\partial{z_1}}{\partial{w_{12}}} \ \frac{\partial{z_1}}{\partial{w_{13}}} \\ \frac{\partial{z_2}}{\partial{w_{21}}} \ \frac{\partial{z_2}}{\partial{w_{22}}} \ \frac{\partial{z_2}}{\partial{w_{23}}} \\ \frac{\partial{z_3}}{\partial{w_{31}}} \ \frac{\partial{z_3}}{\partial{w_{32}}} \ \frac{\partial{z_3}}{\partial{w_{33}}} \\ \frac{\partial{z_4}}{\partial{w_{41}}} \ \frac{\partial{z_4}}{\partial{w_{42}}} \ \frac{\partial{z_4}}{\partial{w_{43}}} \\ \end{bmatrix} = \begin{bmatrix} x_1 \ x_2 \ x_3 \\ x_1 \ x_2 \ x_3 \\ x_1 \ x_2 \ x_3 \\ x_1 \ x_2 \ x_3 \\ \end{bmatrix}, \ \begin{bmatrix} \frac{\partial{z_1}}{\partial{b_1}} \\ \frac{\partial{z_2}}{\partial{b_2}} \\ \frac{\partial{z_3}}{\partial{b_3}} \\ \frac{\partial{z_4}}{\partial{b_4}} \\ \end{bmatrix}= \begin{bmatrix} 1 \\ 1 \\ 1 \\ 1 \\\end{bmatrix}\\ \end{aligned}\)
- Update step:
\(\begin{aligned} &\boldsymbol{W}^{new} = \boldsymbol{W}-\eta \frac{\partial{L}}{\partial{\boldsymbol{W}}}=\boldsymbol{W}-\eta (\boldsymbol{z}-\boldsymbol{y}) \boldsymbol{x}^T \\ &\boldsymbol{b}^{new} = \boldsymbol{b}-\eta \frac{\partial{L}}{\partial{\boldsymbol{b}}}=\boldsymbol{b}-\eta (\boldsymbol{z}-\boldsymbol{y}) \end{aligned}\) where $\eta$ is the learning rate.
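As a quick sanity check, here is a minimal Julia sketch of one forward/backward/update pass for this 4-output, 3-input layer. The weights, inputs, target and learning rate are made-up numbers for illustration and are not taken from the diagrams above.

```julia
# One gradient-descent step for z = W*x + b (toy numbers for illustration)
W = [0.1 0.2 0.3; 0.4 0.5 0.6; 0.7 0.8 0.9; 1.0 1.1 1.2]  # 4x3 weights
b = [0.0; 0.0; 0.0; 0.0]                                   # 4x1 bias
x = [1.0; 2.0; 3.0]                                        # 3x1 input
y = [1.0; 0.0; -1.0; 2.0]                                  # 4x1 target
η = 0.01                                                   # learning rate

z = W*x .+ b              # forward propagate
∂L_∂z = z - y             # ∂L/∂z for L = 0.5*||z - y||²
W = W .- η * ∂L_∂z * x'   # update: ∂L/∂W = (z - y)xᵀ
b = b .- η * ∂L_∂z        # update: ∂L/∂b = (z - y)
```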
We can simplify expression (1) further by absorbing the bias vector $\boldsymbol{b}$ into the weight matrix $\boldsymbol{W}$ and appending a 1 to the bottom of $\boldsymbol{x}$.
\[\begin{aligned} \boldsymbol{W}\boldsymbol{x}&=\boldsymbol{z}\\ \begin{bmatrix} w_{11} \ w_{12} \ w_{13} \ b_1 \\ w_{21} \ w_{22} \ w_{23} \ b_2 \\ w_{31} \ w_{32} \ w_{33} \ b_3 \\ w_{41} \ w_{42} \ w_{43} \ b_4 \end{bmatrix} \begin{bmatrix} x_{1} \\ x_{2} \\ x_{3} \\ 1 \end{bmatrix}&= \begin{bmatrix} z_{1} \\ z_{2} \\ z_{3} \\ z_{4} \end{bmatrix}\\ \end{aligned}\]

Multidimensional input
In the previous section we passed a column-vector input $\boldsymbol{x}$ through the network above to get a column vector $\boldsymbol{z}$. Now let us extend that concept by assuming the input is a matrix instead of a single column vector. The figure above shows a 3x3 matrix input $\boldsymbol{x}$ (3 column vectors concatenated together, brown) passed through the network weights (blue) and bias (green) to produce the output $\boldsymbol{z}$ (red).
- Forward propagate step:
\(\begin{aligned}
\boldsymbol{W}\boldsymbol{x}=& \ \boldsymbol{z}\\
\begin{bmatrix}
w_{11} \ w_{12} \ b_1 \\ w_{21} \ w_{22} \ b_2
\end{bmatrix}
\begin{bmatrix}
x_{11} \ x_{12} \ x_{13} \\ x_{21} \ x_{22} \ x_{23} \\ 1 \ \ \ 1 \ \ \ 1
\end{bmatrix}=&
\begin{bmatrix}
z_{11} \ z_{12} \ z_{13} \\ z_{21} \ z_{22} \ z_{23}
\end{bmatrix}\\
\begin{bmatrix}
w_{11}x_{11}+w_{12}x_{21}+b_1 \ \ \ w_{11}x_{12}+w_{12}x_{22}+b_1 \ \ \ w_{11}x_{13}+w_{12}x_{23}+b_1 \\
w_{21}x_{11}+w_{22}x_{21}+b_2 \ \ \ w_{21}x_{12}+w_{22}x_{22}+b_2 \ \ \ w_{21}x_{13}+w_{22}x_{23}+b_2
\end{bmatrix}=&
\begin{bmatrix}
z_{11} \ z_{12} \ z_{13} \\ z_{21} \ z_{22} \ z_{23}
\end{bmatrix}
\end{aligned}\)
The matrix multiplication operation can be thought of as passing 3 column vector inputs separately to the network and then concatenating the individual outputs next to each other.
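A short Julia sketch of this column-by-column view (the numbers are arbitrary and only for illustration):

```julia
# Passing a matrix through the layer equals passing each column separately
# and concatenating the outputs.
W = [0.5 -1.0 0.1; 2.0 0.3 0.0]          # 2x3, last column holds the bias terms
x = [1.0 4.0 2.0; 3.0 0.3 2.0; 1 1 1]    # 3x3, last row of 1s multiplies the bias
z_all  = W * x                               # all 3 columns at once
z_cols = hcat([W * x[:, j] for j in 1:3]...) # one column at a time
println(z_all ≈ z_cols)                      # true
```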
- Backpropagate step:
Let $\boldsymbol{y}$ be the 2x3 target matrix, and using the L2 loss function, we calculate the error of the 2x3 $\boldsymbol{z}$. We know how to do backpropagation with a column vector (previous section), so let us break the problem into 3 steps: backpropagating with $\boldsymbol{z_1}$, then $\boldsymbol{z_2}$, then $\boldsymbol{z_3}$ (the columns of $\boldsymbol{z}$).
\[\begin{aligned} &L = \frac{1}{2}||\boldsymbol{z} - \boldsymbol{y}||^2_2 \\ &\frac{\partial{L}}{\partial{\boldsymbol{z}}}=\boldsymbol{z}-\boldsymbol{y} \\ \end{aligned}\]

Step 1
- Step 1 - Backpropagate step:
- Step 1 - Update step:
Step 2
- Step 2 - Backpropagate step:
- Step 2 - Update step:
Step 3
- Step 3 - Backpropagate step:
- Step 3 - Update step:
Looking at it closely, the 3 steps are just updating $\boldsymbol{W}$ and the bias parameters 3 times, so we can combine all 3 steps into a single matrix-multiplication step. \(\begin{aligned} \begin{bmatrix} \frac{\partial{L}}{\partial{z_{11}}} \ \frac{\partial{L}}{\partial{z_{12}}} \ \frac{\partial{L}}{\partial{z_{13}}}\\ \frac{\partial{L}}{\partial{z_{21}}} \ \frac{\partial{L}}{\partial{z_{22}}} \ \frac{\partial{L}}{\partial{z_{23}}} \end{bmatrix} \begin{bmatrix} x_{11} \ x_{21} \\ x_{12} \ x_{22} \\ x_{13} \ x_{23} \\ \end{bmatrix}= \begin{bmatrix} \frac{\partial{L}}{\partial{z_{11}}}x_{11}+\frac{\partial{L}}{\partial{z_{12}}}x_{12}+\frac{\partial{L}}{\partial{z_{13}}}x_{13} \ \ \ \frac{\partial{L}}{\partial{z_{11}}}x_{21}+\frac{\partial{L}}{\partial{z_{12}}}x_{22}+\frac{\partial{L}}{\partial{z_{13}}}x_{23} \\ \frac{\partial{L}}{\partial{z_{21}}}x_{11}+\frac{\partial{L}}{\partial{z_{22}}}x_{12}+\frac{\partial{L}}{\partial{z_{23}}}x_{13} \ \ \ \frac{\partial{L}}{\partial{z_{21}}}x_{21}+\frac{\partial{L}}{\partial{z_{22}}}x_{22}+\frac{\partial{L}}{\partial{z_{23}}}x_{23} \\ \end{bmatrix} \end{aligned}\)
- Update step:
\(\begin{aligned} &\boldsymbol{W}^{new} = \boldsymbol{W}-\eta \frac{\partial{L}}{\partial{\boldsymbol{W}}}, \ \boldsymbol{W}^{new} = \boldsymbol{W}-\eta (\boldsymbol{z}-\boldsymbol{y}) \boldsymbol{x}^T \\ &\boldsymbol{b}^{new} = \boldsymbol{b}-\eta \frac{\partial{L}}{\partial{\boldsymbol{b}}}, \ \boldsymbol{b}^{new} = \boldsymbol{b}-\eta \ sumcols(\boldsymbol{z}-\boldsymbol{y}) \end{aligned}\) where sumcols means summing across the columns (one sum per row).
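As a quick check, the following Julia sketch (with made-up numbers) verifies that the single matrix product $\frac{\partial{L}}{\partial{\boldsymbol{z}}}\boldsymbol{x}^T$ equals the sum of the 3 per-column outer products, and that the bias gradient is $\boldsymbol{z}-\boldsymbol{y}$ summed across its columns:

```julia
# Batched gradients vs. the 3 separate per-column "iteration" updates
W = [0.5 -1.0; 2.0 0.3]            # 2x2 weights (bias kept separate here)
b = [0.1; -0.2]                    # 2x1 bias
x = [1.0 4.0 2.0; 3.0 0.3 2.0]     # 2x3 input: columns x1, x2, x3
y = [1.0 0.0 -1.0; 2.0 1.0 0.5]    # 2x3 target
z = W*x .+ b                       # forward propagate

∂L_∂z = z - y
∂L_∂W = ∂L_∂z * x'                 # combined matrix-multiplication step
∂L_∂b = sum(∂L_∂z, dims=2)         # sumcols: sum across the 3 columns

# per-column version: one outer product per input column
∂L_∂W_cols = sum((z[:,j] - y[:,j]) * x[:,j]' for j in 1:3)
∂L_∂b_cols = sum(z[:,j] - y[:,j] for j in 1:3)
println(∂L_∂W ≈ ∂L_∂W_cols, " ", vec(∂L_∂b) ≈ ∂L_∂b_cols)  # true true
```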
Values, V
- Forward propagate step:
Concatenating $\boldsymbol{v_1}$ and $\boldsymbol{v_2}$ into $\boldsymbol{v}$, \(\begin{aligned} \boldsymbol{v} = \begin{bmatrix} v_{11} \ v_{12} \\ v_{21} \ v_{22} \\ v_{31} \ v_{32} \end{bmatrix} \end{aligned}\)
$\boldsymbol{z}$ can be computed vectorially as \(\begin{aligned} &\boldsymbol{z} = \begin{bmatrix} v_{11} \ v_{12} \\ v_{21} \ v_{22} \\ v_{31} \ v_{32} \end{bmatrix} \begin{bmatrix} \alpha_{1} \\ \alpha_{2} \end{bmatrix}\\ &\boldsymbol{z} = \boldsymbol{v} \boldsymbol{\alpha}^T \end{aligned}\)
- Backpropagate step:
Let $\boldsymbol{y}$ be the 3x1 target vector.
\[\begin{aligned} &L = \frac{1}{2}||\boldsymbol{z} - \boldsymbol{y}||^2_2 \\ &\frac{\partial{L}}{\partial{\boldsymbol{z}}}=\boldsymbol{z}-\boldsymbol{y} \\ &\frac{\partial{L}}{\partial{\boldsymbol{z}}}=\begin{bmatrix} z_1 - y_1 \\ z_2 - y_2 \\ z_3 - y_3 \\ \end{bmatrix}\\ &\frac{\partial{\boldsymbol{z}}}{\partial{\boldsymbol{v}}}=\boldsymbol{\alpha}\\ &\begin{bmatrix} \frac{\partial{z_1}}{\partial{v_{11}}} \ \frac{\partial{z_1}}{\partial{v_{12}}} \\ \frac{\partial{z_2}}{\partial{v_{21}}} \ \frac{\partial{z_2}}{\partial{v_{22}}} \\ \frac{\partial{z_3}}{\partial{v_{31}}} \ \frac{\partial{z_3}}{\partial{v_{32}}} \\ \end{bmatrix} = \begin{bmatrix} \alpha_{1} \ \alpha_{2} \\ \alpha_{1} \ \alpha_{2} \\ \alpha_{1} \ \alpha_{2} \\ \end{bmatrix}, \\ &\frac{\partial{L}}{\partial{\boldsymbol{v}}}=\frac{\partial{L}}{\partial{\boldsymbol{z}}}\frac{\partial{\boldsymbol{z}}}{\partial{\boldsymbol{v}}}\\ \end{aligned}\]

Intuitively from the left diagram, we can think of the errors being backpropagated to $v_1$ and $v_2$ by the scaling factors $\alpha_1$, $\alpha_2$ respectively.
- Update step: \(\begin{aligned} &\boldsymbol{v}^{new}=\boldsymbol{v}-\eta\frac{\partial{L}}{\partial{\boldsymbol{v}}}\\ &\boldsymbol{v}^{new}=\boldsymbol{v}-\eta(\boldsymbol{z}-\boldsymbol{y})\boldsymbol{\alpha}\\ \end{aligned}\)
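A minimal Julia sketch of this update, with made-up values for $\boldsymbol{v}$, $\boldsymbol{\alpha}$, the target and the learning rate:

```julia
# One gradient step on the values v, with α treated as fixed attention weights
v = [1.0 2.0; 3.0 4.0; 5.0 6.0]   # 3x2: columns are v1, v2
α = [0.7 0.3]                     # 1x2 row of attention weights
y = [-1.0; 20.0; 5.0]             # 3x1 target
η = 0.01                          # learning rate

z = v * α'                        # forward propagate: 3x1 output
∂L_∂v = (z - y) * α               # ∂L/∂v = (z - y)α, a 3x2 matrix
v = v .- η * ∂L_∂v                # update step
```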
Next, let us start stacking layers. For example, instead of $v_1$ & $v_2$ being the trainable parameters, they could be the output of a 2-layer feedforward network as shown below.
In this case we have the matrix $\boldsymbol{V}$ as the trainable parameters and $\boldsymbol{v}=\boldsymbol{V}\boldsymbol{x}+\boldsymbol{b}$
- Backpropagate step:
- Update step:
Notice that the product of the purple and brown blocks has an implied sum over the changes to $\boldsymbol{V}$ contributed by both $v_1$ and $v_2$. We can think of it as combining 2 independent sample ‘iteration’ updates (pass $x_1$ in iteration 1 and update $\boldsymbol{V}$, then pass $x_2$ in iteration 2 and update $\boldsymbol{V}$).
\[\begin{aligned} &\boldsymbol{b}^{new}=\boldsymbol{b}-\eta\frac{\partial{L}}{\partial{\boldsymbol{b}}}\\ &\boldsymbol{b}^{new}=\boldsymbol{b}-\eta \ sumcols((\boldsymbol{z}-\boldsymbol{y})\boldsymbol{\alpha})\\ \end{aligned}\]

To update the bias parameter, we need to sum the 2 columns of $\frac{\partial{L}}{\partial{\boldsymbol{v}}}$, because the bias is shared by both $x_1$ and $x_2$. Think of it as combining 2 sample ‘iteration’ updates (pass $x_1$ in iteration 1 and update $\boldsymbol{V}$, $b$; pass $x_2$ in iteration 2 and update $\boldsymbol{V}$, $b$).
The following Julia code illustrates an example.
#L2 norm
function L2norm(x)
sqrt(sum(x.^2))
end
#Squared Loss
function Squared_Loss(z,x)
return 0.5*(L2norm(z-x))^2
end
#Computes output z
function FeedForward(W,b,x,y,α)
v = W*x.+b
z = v * α'
return z
end
# Forward propagate
function forwardprop(W,b,x,y,α)
z = FeedForward(W,b,x,y,α)
return backprop(W,b,x,z,y,α)
end
# Backpropagate
function backprop(W,b,x,z,y,α,η=0.002)
println("Loss:",Squared_Loss(z,y))
∂L_∂z = z-y
∂z_∂v = α
∂v_∂w = x'
∂L_∂w = ∂L_∂z*∂z_∂v*∂v_∂w
∂L_∂b = ∂L_∂z*∂z_∂v
#init weights
W_new = W
b_new = b
#update step
W_new = W .- η * ∂L_∂w
b_new = b .- η * sum(∂L_∂b,dims=2)
return W_new, b_new
end
V = [-1 6 7;2 2 2;5 5 5]
V_bias = [0;0;0]
α = [5 -2]
x = [[1 3]; [4 .3]; [2 2]]
y = [-1; 20; 5;]
for i=1:10
global V, V_bias #update the globals above (needed when run as a script)
V, V_bias = forwardprop(V,V_bias,x,y,α)
end
println("V:",V)
println("V_bias:",V_bias)
println("z:",FeedForward(V,V_bias,x,y,α))
Query, Q
- Forward propagate step:
Note that we have absorbed the bias terms $b_1,b_2$ into $Q$ and appended ‘1’ to the vector $\boldsymbol{x_1}$ for compactness. The above steps can be written concisely in vector form, \(\begin{aligned} &\boldsymbol{e}=(\boldsymbol{Qx_1})^T\boldsymbol{k}\\ &\boldsymbol{e}=\boldsymbol{q_1}^T\boldsymbol{k} \end{aligned}\)
- Backpropagate step:
Let $\boldsymbol{y}$ be the 2x1 target vector.
\[\begin{aligned} &L = \frac{1}{2}||\boldsymbol{e} - \boldsymbol{y}^T||^2_2 \\ &\frac{\partial{L}}{\partial{\boldsymbol{e}}}=\boldsymbol{e} -\boldsymbol{y}^T \\ &\frac{\partial{L}}{\partial{\boldsymbol{e}}}=\begin{bmatrix} e_1 - y_1 \ e_2 - y_2 \\ \end{bmatrix}\\ &\frac{\partial{\boldsymbol{e}}}{\partial{\boldsymbol{q_1}}}=\boldsymbol{k}^T\\ &\begin{bmatrix} \frac{\partial{e_1}}{\partial{q_{11}}} \ \frac{\partial{e_1}}{\partial{q_{21}}} \\ \frac{\partial{e_2}}{\partial{q_{11}}} \ \frac{\partial{e_2}}{\partial{q_{21}}} \\ \end{bmatrix} = \begin{bmatrix} k_{11} \ k_{21} \\ k_{12} \ k_{22} \\ \end{bmatrix}, \\ &\frac{\partial{\boldsymbol{q_1}}}{\partial{\boldsymbol{Q}}}=\boldsymbol{x_1}^T,\ &\frac{\partial{\boldsymbol{q_1}}}{\partial{\boldsymbol{b}}}=\boldsymbol{1}\\ &\frac{\partial{L}}{\partial{\boldsymbol{Q}}}=\frac{\partial{L}}{\partial{\boldsymbol{e}}}\frac{\partial{\boldsymbol{e}}}{\partial{\boldsymbol{q_1}}}\frac{\partial{\boldsymbol{q_1}}}{\partial{\boldsymbol{Q}}}\\ &\frac{\partial{L}}{\partial{\boldsymbol{q_1}}}=\frac{\partial{L}}{\partial{\boldsymbol{e}}}\frac{\partial{\boldsymbol{e}}}{\partial{\boldsymbol{q_1}}}=\begin{bmatrix} e_1 - y_1 \ e_2 - y_2 \\ \end{bmatrix}\begin{bmatrix} k_{11} \ k_{21} \\ k_{12} \ k_{22} \\ \end{bmatrix}\\ \end{aligned}\]

- Update step:
The following Julia code illustrates an example with three 3x1 keys $\boldsymbol{k_1},\boldsymbol{k_2},\boldsymbol{k_3}$, a 3x4 $Q$ matrix and a 4x1 input vector $\boldsymbol{x}$.
# Squared Loss
function Squared_Loss(z,x)
return 0.5*(L2norm(z-x))^2
end
# L2 norm
function L2norm(x)
sqrt(sum(x.^2))
end
#Feed forward
function FeedForward(Q,k,x,y)
q = Q*x
e = q'*k
return q,e
end
# Forward propagate
function forwardprop(Q,k,x,y)
q, e = FeedForward(Q,k,x,y)
return backprop(Q,k,x,q,e,y)
end
# Backpropagate
function backprop(Q,k,x,q,e,y,η=.02)
println(Squared_Loss(e',y))
∂L_∂e = (e-y')
∂e_∂q = k'
∂q_∂Q = x[1:end-1]'
∂L_∂q = ∂L_∂e*∂e_∂q
∂L_∂Q = ∂L_∂q'*∂q_∂Q #outer product (3x1)*(1x3): ∂L/∂Q[i,j] = ∂L/∂q_i * x_j
∂L_∂Qb = ∂L_∂q
#init Q weights
Q_new = Q
#update step
Q_new[:,1:end-1] = Q_new[:,1:end-1] .- η * ∂L_∂Q
Q_new[:,end:end] = Q_new[:,end:end] .- η * ∂L_∂Qb'
return Q_new
end
Q = [1. 5 .2 0; 1 2 .4 0; 4 5 1 0;] #Last column is bias terms initialized to 0
k = [1 -3 6; 2 0 1; 4 5 1;]
x = [0.5; 0.5; .3; 1] #Last row is bias term set to 1
y = [-3.14; -6.3; 2.21;]
for i=1:100
global Q #update the global Q (needed when run as a script)
Q = forwardprop(Q,k,x,y)
end
println("Q:",Q)
println("e:",FeedForward(Q,k,x,y)[2])
Keys, K
- Forward propagate step:
Note that as before we have absorbed the bias terms $b_1,b_2$ into $K$ and appended ‘1’s to the vectors $\boldsymbol{x_1}$, $\boldsymbol{x_2}$ for compactness. The above steps can be written concisely in vector form, \(\begin{aligned} &\boldsymbol{e}=\boldsymbol{q_1}^T(\boldsymbol{K}\boldsymbol{x})\\ &\boldsymbol{e}=\boldsymbol{q_1}^T\boldsymbol{k} \end{aligned}\)
- Backpropagate step:
Let $\boldsymbol{y}$ be the 2x1 target vector.
\[\begin{aligned} &L = \frac{1}{2}||\boldsymbol{e} - \boldsymbol{y}^T||^2_2 \\ &\frac{\partial{L}}{\partial{\boldsymbol{e}}}=\boldsymbol{e} -\boldsymbol{y}^T \\ &\frac{\partial{L}}{\partial{\boldsymbol{e}}}=\begin{bmatrix} e_1 - y_1 \ e_2 - y_2 \\ \end{bmatrix}\\ &\frac{\partial{\boldsymbol{e}}}{\partial{\boldsymbol{k}}}=\boldsymbol{q_1}^T\\ &\begin{bmatrix} \frac{\partial{e_1}}{\partial{k_{11}}} \ \frac{\partial{e_1}}{\partial{k_{21}}} \\ \frac{\partial{e_2}}{\partial{k_{12}}} \ \frac{\partial{e_2}}{\partial{k_{22}}} \\ \end{bmatrix} = \begin{bmatrix} q_{11} \ q_{21} \\ q_{11} \ q_{21} \\ \end{bmatrix}, \\ &\frac{\partial{\boldsymbol{k}}}{\partial{\boldsymbol{K}}}=\boldsymbol{x}^T,\ &\frac{\partial{\boldsymbol{k}}}{\partial{\boldsymbol{b}}}=\boldsymbol{1}\\ &\frac{\partial{L}}{\partial{\boldsymbol{K}}}=\frac{\partial{L}}{\partial{\boldsymbol{e}}}\frac{\partial{\boldsymbol{e}}}{\partial{\boldsymbol{k}}}\frac{\partial{\boldsymbol{k}}}{\partial{\boldsymbol{K}}}\\ &\frac{\partial{L}}{\partial{\boldsymbol{k}}}=\frac{\partial{L}}{\partial{\boldsymbol{e}}}\frac{\partial{\boldsymbol{e}}}{\partial{\boldsymbol{k}}}=\begin{bmatrix} e_1 - y_1 \\ e_2 - y_2 \\ \end{bmatrix}\begin{bmatrix} q_{11} \ q_{21} \\ \end{bmatrix}\\ &\frac{\partial{L}}{\partial{\boldsymbol{k}}}= \begin{bmatrix} \frac{\partial{L}}{\partial{\boldsymbol{k_1}}}\\ \frac{\partial{L}}{\partial{\boldsymbol{k_2}}} \\ \end{bmatrix}= \begin{bmatrix} q_{11} (e_1 - y_1) \ \ q_{21} (e_1 - y_1) \\ q_{11}(e_2 - y_2) \ \ q_{21} (e_2 - y_2) \\ \end{bmatrix} \end{aligned}\]

- Update step:
The following Julia code illustrates an example with three 3x1 keys $\boldsymbol{k_1},\boldsymbol{k_2},\boldsymbol{k_3}$, a 3x4 $K$ matrix and a 4x3 input matrix $\boldsymbol{x}$.
# Squared Loss
function Squared_Loss(z,x)
return 0.5*(L2norm(z-x))^2
end
# L2 norm
function L2norm(x)
sqrt(sum(x.^2))
end
#Feed forward
function FeedForward(K,q,x,y)
k = K*x
e = q'*k
return k,e
end
# Forward propagate
function forwardprop(K,q,x,y)
k, e = FeedForward(K,q,x,y)
return backprop(K,k,x,q,e,y)
end
# Backpropagate
function backprop(K,k,x,q,e,y,η=.01)
println(Squared_Loss(e',y))
∂L_∂e = (e-y')'
∂e_∂k = q'
∂k_∂K = x[1:end-1,:]'
∂L_∂k = ∂L_∂e*∂e_∂k
∂L_∂K = ∂L_∂k'*∂k_∂K
∂L_∂Kb = ∂L_∂k'
#init K weights
K_new = K
#update step
K_new[:,1:end-1] = K_new[:,1:end-1] .- η * ∂L_∂K
K_new[:,end:end] = K_new[:,end:end] .- η * sum(∂L_∂Kb,dims=2)
return K_new
end
K = [1 -3 6. 0; 2 0 1 0; 4 5 1 0;] #Last column is bias terms initialized to 0
q = [1.; 5; .2;]
x = [0.5 .2 .4; 0.5 .2 .6; .3 .25 .7; 1 1 1;] #Last row are bias terms set to 1
y = [-3.14; -6.3; 2.21;]
for i=1:300
global K #update the global K (needed when run as a script)
K = forwardprop(K,q,x,y)
end
println("K:",K)
println("e:",FeedForward(K,q,x,y)[2])
Softmax
The left diagram above shows an abstraction of the softmax function as a black box. If we pass 2 values $e_1, e_2$ through the black-box softmax function, we get 2 values $\alpha_1, \alpha_2$ out. The right diagram shows the circuit-like representation (internals) of the softmax function.
- Forward step:
\(\begin{aligned} softmax \ function, \ \alpha_i = \sigma{(e_i)}=\frac{e^{e_i}}{\sum_{j=1}^2 {e^{e_j}}} \\ for \ i=1,2 \end{aligned}\)
- Backpropagate step:
\(\begin{aligned} \frac{\partial{L}}{\partial{\boldsymbol{\alpha}}}&= \begin{bmatrix} \frac{\partial{L}}{\partial{\alpha_1}} \ \frac{\partial{L}}{\partial{\alpha_2}} \end{bmatrix}\\ \frac{\partial{L}}{\partial{e_i}}&= \frac{\partial{L}}{\partial{\boldsymbol{\alpha}}}\frac{\partial{\boldsymbol{\alpha}}}{\partial{e_i}}\\ \frac{\partial{\boldsymbol{\alpha}}}{\partial{e_i}}&= \begin{bmatrix} \frac{\partial{\alpha_1}}{\partial{e_i}} \\ \frac{\partial{\alpha_2}}{\partial{e_i}} \end{bmatrix}\\ \frac{\partial{L}}{\partial{e_1}}&= \frac{\partial{L}}{\partial{\alpha_1}}\frac{\partial{\alpha_1}}{\partial{e_1}}+\frac{\partial{L}}{\partial{\alpha_2}}\frac{\partial{\alpha_2}}{\partial{e_1}}\\ &=\frac{\partial{L}}{\partial{\alpha_1}}(\alpha_1)(1-\alpha_1)+\frac{\partial{L}}{\partial{\alpha_2}}(-\alpha_1\alpha_2)\\ \frac{\partial{L}}{\partial{e_2}}&=\frac{\partial{L}}{\partial{\alpha_1}}\frac{\partial{\alpha_1}}{\partial{e_2}}+\frac{\partial{L}}{\partial{\alpha_2}}\frac{\partial{\alpha_2}}{\partial{e_2}}\\ &=\frac{\partial{L}}{\partial{\alpha_1}}(-\alpha_1\alpha_2)+\frac{\partial{L}}{\partial{\alpha_2}}(\alpha_2)(1-\alpha_2)\\ \end{aligned}\)
To see this visually as signals flowing back, let $a=e^{e_1}, c=e^{e_2}, b=e^{e_1}+e^{e_2}, \alpha_1=\frac{a}{b},\alpha_2=\frac{c}{b}$. Next, let us rewrite the division operation as a product of $a$ and $1/b$.
Referring to the above figure, the red and blue arrows denote signals from $\alpha_1$ and $\alpha_2$ respectively. Let us first zoom in on the $\alpha_1$ output. On the backward pass, the signal forks into 2 branches: the top and mid branches.
On the top branch,
\(\begin{aligned}
\frac{\partial{\alpha_1}}{\partial{a}}&=\frac{1}{b}\\
\frac{\partial{a}}{\partial{e_1}}&=e^{e_1}\\
(\frac{\partial{\alpha_1}}{\partial{e_1}})_{top}&=\frac{e^{e_1}}{b}\\
\end{aligned}\)
On the mid branch,
\(\begin{aligned}
\frac{\partial{\alpha_1}}{\partial{b}}&=-\frac{a}{b^2}\\
\frac{\partial{b}}{\partial{e_1}}&=e^{e_1}\\
(\frac{\partial{\alpha_1}}{\partial{e_1}})_{mid}&=-\frac{e^{e_1}a}{b^2}\\
\end{aligned}\)
As both paths converge at the ‘exp’ node we add them together giving
\(\begin{aligned}
\frac{\partial{\alpha_1}}{\partial{e_1}}&=\frac{e^{e_1}}{b}-\frac{e^{e_1}a}{b^2}\\
&=\frac{e^{e_1}}{b}(1-\frac{a}{b})\\
&=\alpha_1(1-\alpha_1)
\end{aligned}\)
Now zooming in to the $\alpha_2$ output and working backwards, we see that we have a path to $e_1$ via the $b$ mid branch. On the mid branch, \(\begin{aligned} \frac{\partial{\alpha_2}}{\partial{b}}&=-\frac{c}{b^2}\\ \frac{\partial{b}}{\partial{e_1}}&=e^{e_1}\\ (\frac{\partial{\alpha_2}}{\partial{e_1}})_{mid}&=-\frac{e^{e_1}c}{b^2}\\ &=-\alpha_1\alpha_2 \end{aligned}\)
The same steps can be applied to obtain the gradients for $e_2$.
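To tie this together, here is a small Julia sketch (not from the original figures; the input and upstream-gradient values are made up) that implements the 2-input softmax forward pass and the backward formulas above, and checks them against a central-difference approximation:

```julia
# Softmax forward and backward pass for 2 inputs, checked numerically
softmax(e) = exp.(e) ./ sum(exp.(e))

# Backward pass: given upstream gradients ∂L/∂α, return [∂L/∂e1, ∂L/∂e2]
function softmax_backward(e, ∂L_∂α)
    α = softmax(e)
    ∂L_∂e1 = ∂L_∂α[1]*α[1]*(1 - α[1]) + ∂L_∂α[2]*(-α[1]*α[2])
    ∂L_∂e2 = ∂L_∂α[1]*(-α[1]*α[2])   + ∂L_∂α[2]*α[2]*(1 - α[2])
    return [∂L_∂e1, ∂L_∂e2]
end

e     = [0.3, -1.2]
∂L_∂α = [0.5, -2.0]                  # some upstream gradient
analytic = softmax_backward(e, ∂L_∂α)

# Central differences of g(e) = ∂L_∂α ⋅ softmax(e); its gradient is the same
# chain-rule product ∂L/∂α ⋅ ∂α/∂e computed above.
ϵ = 1e-6
numeric = [(sum(∂L_∂α .* softmax(e .+ ϵ .* [i==1, i==2])) -
            sum(∂L_∂α .* softmax(e .- ϵ .* [i==1, i==2]))) / (2ϵ) for i in 1:2]
println(analytic ≈ numeric)          # true
```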