Natural Gradient Descent
Natural gradient descent is an optimization method that accounts for the geometry of the parameter space when updating model parameters. Unlike standard gradient descent, which treats parameter space as Euclidean, natural gradient descent uses the Fisher information matrix to define a Riemannian metric on the space of probability distributions.
Standard Gradient Descent
In standard gradient descent, we update parameters $\theta$ by:
\[\theta_{t+1} = \theta_t - \eta \nabla_\theta \mathcal{L}(\theta_t)\]where $\eta$ is the learning rate. This assumes that the parameter space is flat (Euclidean), meaning a small step $\delta\theta$ is measured the same way regardless of where we are in parameter space. However, for probabilistic models, this assumption is often inappropriate.
Standard Gradient Descent Example
Consider a simple 2D optimization problem where we want to estimate the mean $\mu = [\theta_1, \theta_2]$ of a Gaussian distribution with a fixed, diagonal covariance matrix $\Sigma = \text{diag}(1, 0.1)$. For an observation at the origin, the loss function is the negative log-likelihood, which simplifies (up to an additive constant) to a weighted squared error:
\[\mathcal{L}(\theta) = \frac{1}{2} (\theta_1^2 + 10\theta_2^2)\]The contours of this loss function form ellipses. In standard gradient descent, each update step is perpendicular to the contour at the current point, a direction that does not generally point towards the minimum.
```python
import numpy as np
import plotly.graph_objects as go
import torch
from IPython.display import HTML, display

# Define the loss function (negative log-likelihood for the Gaussian)
# L(theta) = 0.5 * (theta_1^2 + 10 * theta_2^2)
def loss_fn(theta):
    return 0.5 * (theta[0] ** 2 + 10 * theta[1] ** 2)

# Visualization setup
theta_1 = np.linspace(-3, 3, 100)
theta_2 = np.linspace(-2.25, 2.25, 100)
T1, T2 = np.meshgrid(theta_1, theta_2)
Z = 0.5 * (T1**2 + 10 * T2**2)

# Standard gradient descent
theta = torch.tensor([-2.5, 0.8], requires_grad=True)
lr = 0.18  # Large enough to overshoot in the high-curvature direction
path_sgd = [theta.detach().numpy().copy()]
for _ in range(30):
    loss = loss_fn(theta)
    loss.backward()
    with torch.no_grad():
        theta -= lr * theta.grad
        theta.grad.zero_()
    path_sgd.append(theta.detach().numpy().copy())

# Plotting the trace
path_sgd = np.array(path_sgd)
fig = go.Figure()
# Add contour plot
fig.add_trace(
    go.Contour(
        x=theta_1,
        y=theta_2,
        z=Z,
        colorscale="Viridis",
        showscale=False,
        ncontours=20,
        contours=dict(showlabels=False),
    )
)
# Add gradient descent path
fig.add_trace(
    go.Scatter(
        x=path_sgd[:, 0],
        y=path_sgd[:, 1],
        mode="lines+markers",
        name="Standard GD",
        line=dict(color="red"),
        marker=dict(size=6, color="red"),
    )
)
fig.update_layout(
    title="Standard Gradient Descent",
    xaxis_title="Theta 1",
    yaxis_title="Theta 2",
    yaxis=dict(scaleanchor="x", scaleratio=1),
)
# Save as HTML and display
html_output = fig.to_html(include_plotlyjs="cdn", full_html=False)
display(HTML(html_output))
```
As seen in the plot generated by the code above, standard gradient descent zig-zags across the narrow valley and converges slowly because it is sensitive to the scaling of the parameters: it does not account for the fact that the loss has much higher curvature (a sharper valley) along $\theta_2$ than along $\theta_1$.
The Fisher Information Matrix
The Fisher information matrix $\mathbf{F}$ captures the local curvature of the KL divergence between nearby distributions:
\[\mathbf{F}(\theta) = \mathbb{E}_{p(x\mid\theta)} \left[ \nabla_\theta \log p(x\mid\theta) \, \nabla_\theta \log p(x\mid\theta)^\top \right]\]It defines a Riemannian metric on the statistical manifold of parameterized distributions, making it the natural notion of “distance” between nearby distributions.
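To make this definition concrete, here is a minimal Monte Carlo sketch using the Gaussian setup from the earlier example (the sampling code is my own illustration): the score with respect to the mean is $\Sigma^{-1}(x - \theta)$, and averaging its outer products over samples should recover $\mathbf{F} = \Sigma^{-1} = \text{diag}(1, 10)$.

```python
import torch

# Monte Carlo estimate of the Fisher matrix for N(theta, Sigma) with fixed
# Sigma = diag(1, 0.1). The score with respect to the mean is Sigma^-1 (x - theta).
torch.manual_seed(0)
theta = torch.tensor([-2.5, 0.8])
Sigma_inv = torch.diag(torch.tensor([1.0, 10.0]))
std = torch.tensor([1.0, 0.1]).sqrt()      # per-coordinate standard deviations

x = theta + std * torch.randn(500_000, 2)  # samples from N(theta, Sigma)
score = (x - theta) @ Sigma_inv            # per-sample score vectors (rows)
F_mc = score.T @ score / len(x)            # Monte Carlo E[score score^T]
print(F_mc)                                # approximately [[1, 0], [0, 10]]
```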
Bartlett’s Identity and the Information Matrix
Deriving Bartlett’s identities will help us connect the Fisher Information Matrix to the measurement of distance between distributions. We start with the fact that the probability density function must integrate to 1:
\[\int p(x\mid\theta) \, dx = 1\]Differentiating both sides with respect to $\theta$:
\[\nabla_\theta \int p(x\mid\theta) \, dx = \int \nabla_\theta p(x\mid\theta) \, dx = \int p(x\mid\theta) \nabla_\theta \log p(x\mid\theta) \, dx = \mathbb{E}_{x \sim p(x\mid\theta)}[\nabla_\theta \log p(x\mid\theta)] = 0\]This is the first identity (mean score is zero). Differentiating again with respect to $\theta^\top$:
\[\nabla_\theta \left( \int p(x\mid\theta) \nabla_\theta \log p(x\mid\theta)^\top \, dx \right) = 0\]Using the product rule inside the integral:
\[\int \left( \nabla_\theta p(x\mid\theta) \nabla_\theta \log p(x\mid\theta)^\top + p(x\mid\theta) \nabla^2_\theta \log p(x\mid\theta) \right) \, dx = 0\]Substitute $\nabla_\theta p(x\mid\theta) = p(x\mid\theta) \nabla_\theta \log p(x\mid\theta)$:
\[\int p(x\mid\theta) \left( \nabla_\theta \log p(x\mid\theta) \nabla_\theta \log p(x\mid\theta)^\top + \nabla^2_\theta \log p(x\mid\theta) \right) \, dx = 0\]This implies:
\[\mathbb{E}_{x \sim p(x\mid\theta)} \left[ \nabla_\theta \log p(x\mid\theta) \nabla_\theta \log p(x\mid\theta)^\top \right] + \mathbb{E}_{x \sim p(x\mid\theta)} \left[ \nabla^2_\theta \log p(x\mid\theta) \right] = 0\]Thus, we obtain Bartlett’s second identity:
\[\mathbf{F}(\theta) = -\mathbb{E}_{x \sim p(x\mid\theta)} \left[ \nabla^2_\theta \log p(x\mid\theta) \right]\]
Fisher Information and Local Distance
The Fisher Information Matrix defines a statistical distance—a Riemannian metric—over the space of probability distributions. To see this, consider the Kullback-Leibler (KL) divergence between a distribution parameterized by $\theta$ and a slightly perturbed distribution parameterized by $\theta + \delta\theta$:
\[D_{KL}(p(x\mid\theta) \,\|\, p(x\mid\theta + \delta\theta)) = \mathbb{E}_{x \sim p(x\mid\theta)} \left[ \log p(x\mid\theta) - \log p(x\mid\theta + \delta\theta) \right]\]We can use a Taylor series expansion of $\log p(x\mid\theta + \delta\theta)$ around $\theta$:
\[\log p(x\mid\theta + \delta\theta) \approx \log p(x\mid\theta) + \nabla_\theta \log p(x\mid\theta)^\top \delta\theta + \frac{1}{2} \delta\theta^\top \nabla^2_\theta \log p(x\mid\theta) \delta\theta\]Substituting this back into the KL divergence expression:
\[\begin{aligned} D_{KL}(p(x\mid\theta) \,\|\, p(x\mid\theta + \delta\theta)) &\approx \mathbb{E}_{x \sim p(x\mid\theta)} \left[ -\nabla_\theta \log p(x\mid\theta)^\top \delta\theta - \frac{1}{2} \delta\theta^\top \nabla^2_\theta \log p(x\mid\theta) \delta\theta \right] \\ &= -\underbrace{\mathbb{E} \left[ \nabla_\theta \log p(x\mid\theta) \right]^\top}_{=0} \delta\theta - \frac{1}{2} \delta\theta^\top \mathbb{E} \left[ \nabla^2_\theta \log p(x\mid\theta) \right] \delta\theta \\ &= -\frac{1}{2} \delta\theta^\top \mathbb{E} \left[ \nabla^2_\theta \log p(x\mid\theta) \right] \delta\theta \\ &= \frac{1}{2} \delta\theta^\top \mathbf{F}(\theta) \delta\theta \end{aligned}\]This shows that the Fisher Information Matrix $\mathbf{F}(\theta)$ acts as the metric tensor that measures the local curvature of the KL divergence manifold.
Notice that the expected score $\mathbb{E} [\nabla_\theta \log p(x\mid\theta)]$ is zero (Bartlett’s first identity). The remaining term involves the expected Hessian of the log-likelihood, which is the Fisher Information Matrix (Bartlett’s second identity).
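Both identities are easy to verify numerically. The sketch below uses a 1D Gaussian with known $\sigma$ (a toy setup of my own): the sampled score should have mean close to zero, and its second moment should match $-\mathbb{E}[\nabla^2_\theta \log p] = 1/\sigma^2$.

```python
import torch

# Verify both Bartlett identities for a 1D Gaussian N(mu, sigma^2) with known
# sigma. Here d/dmu log p = (x - mu)/sigma^2 and d^2/dmu^2 log p = -1/sigma^2.
torch.manual_seed(0)
mu, sigma = 1.5, 2.0
x = mu + sigma * torch.randn(1_000_000)

score = (x - mu) / sigma**2
print(score.mean().item())       # ~0.0   (first identity: the mean score is zero)
print((score**2).mean().item())  # ~0.25  (E[score^2], the Fisher information)
print(1 / sigma**2)              # 0.25   (-E[Hessian], the second identity)
```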
Symmetry of Local KL Divergence
A common question is whether it matters which direction of the KL divergence we minimize: $D_{KL}(p_\theta \,\|\, p_{\theta+\delta\theta})$ (Forward KL) or $D_{KL}(p_{\theta+\delta\theta} \,\|\, p_\theta)$ (Reverse KL).
While the KL divergence is generally asymmetric (i.e., $D_{KL}(P \,\|\, Q) \neq D_{KL}(Q \,\|\, P)$), for infinitesimally small steps $\delta\theta$ both directions reduce to the same symmetric, quadratic approximation.
Let’s expand the Reverse KL divergence $D_{KL}(p_{\theta+\delta\theta} \,\|\, p_\theta)$:
\[D_{KL}(p_{\theta+\delta\theta} \| p_\theta) = \mathbb{E}_{x \sim p_{\theta+\delta\theta}} \left[ \log p(x \mid \theta+\delta\theta) - \log p(x \mid \theta) \right]\]We use a Taylor expansion of the log-likelihood ratio around $\theta$:
\[\log p(x \mid \theta+\delta\theta) - \log p(x \mid \theta) \approx \nabla_\theta \log p(x \mid \theta)^\top \delta\theta + \frac{1}{2}\delta\theta^\top \nabla^2_\theta \log p(x \mid \theta) \delta\theta\]To evaluate the expectation over the perturbed distribution $p_{\theta+\delta\theta}$, we can rewrite the density:
\[\begin{aligned} p(x \mid \theta+\delta\theta) &\approx p(x \mid \theta) + \nabla_\theta p(x \mid \theta)^\top \delta\theta \\ &= p(x \mid \theta) + p(x \mid \theta)\nabla_\theta \log p(x \mid \theta)^\top \delta\theta \\ &= p(x \mid \theta) \left( 1 + \nabla_\theta \log p(x \mid \theta)^\top \delta\theta \right) \end{aligned}\]This allows us to convert the expectation over $p_{\theta+\delta\theta}$ into an expectation over $p_\theta$:
\[\mathbb{E}_{p_{\theta+\delta\theta}}[f(x)] \approx \mathbb{E}_{p_\theta} \left[ f(x) \left( 1 + \nabla_\theta \log p(x \mid \theta)^\top \delta\theta \right) \right]\]Substituting $f(x)$ with the log-likelihood ratio expansion (using the abbreviation: $\nabla \log p = \nabla_\theta \log p(x \mid \theta)$):
\[\begin{aligned} \mathbb{E}_{p_{\theta+\delta\theta}} \left[ \log \frac{p_{\theta+\delta\theta}}{p_\theta} \right] &\approx \mathbb{E}_{p_\theta} \left[ \underbrace{\left(1 + \nabla \log p^\top \delta\theta\right)}_{\text{density ratio } \frac{p_{\theta+\delta\theta}}{p_\theta}} \underbrace{\left(\nabla \log p^\top \delta\theta + \frac{1}{2} \delta\theta^\top \nabla^2 \log p \delta\theta\right)}_{\text{log-ratio expansion}} \right] \\ &= \mathbb{E}_{p_\theta} \left[ \underbrace{1 \cdot (\nabla \log p^\top \delta\theta)}_{\text{Order 1}} + \underbrace{1 \cdot (\frac{1}{2} \delta\theta^\top \nabla^2 \log p \delta\theta)}_{\text{Order 2 (Hessian)}} + \underbrace{(\nabla \log p^\top \delta\theta) \cdot (\nabla \log p^\top \delta\theta)}_{\text{Order 2 (Cross term)}} + O(\delta\theta^3) \right] \\ &= \mathbb{E}_{p_\theta} \left[ \nabla \log p^\top \delta\theta \right] + \frac{1}{2} \delta\theta^\top \mathbb{E}_{p_\theta} [\nabla^2 \log p] \delta\theta + \delta\theta^\top \mathbb{E}_{p_\theta} [(\nabla \log p)(\nabla \log p)^\top] \delta\theta \end{aligned}\]
- First Term: $\mathbb{E}[\nabla \log p]^\top \delta\theta = 0$ (Bartlett’s first identity).
- Second Term: $\mathbb{E}[\nabla^2 \log p] = -\mathbf{F}$ (Bartlett’s second identity).
- Third Term: $\mathbb{E}[(\nabla \log p)(\nabla \log p)^\top] = \mathbf{F}$ (Definition of Fisher Information).
Combining these:
\[\begin{aligned} D_{KL}(p_{\theta+\delta\theta} \| p_\theta) &\approx 0 - \frac{1}{2} \delta\theta^\top \mathbf{F} \delta\theta + \delta\theta^\top \mathbf{F} \delta\theta \\ &= \frac{1}{2} \delta\theta^\top \mathbf{F}(\theta) \delta\theta \end{aligned}\]This matches the result for the Forward KL divergence derived earlier. Thus, locally, the KL divergence behaves like a squared distance metric (a quadratic form weighted by $\mathbf{F}$), proving that the direction of minimization does not change the resulting Natural Gradient update.
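We can check this local symmetry numerically with a distribution whose KL divergence is genuinely asymmetric. The sketch below (my own toy setup) uses an exponential distribution with rate $\theta$, for which $D_{KL}(\text{Exp}(a) \,\|\, \text{Exp}(b)) = \ln(a/b) + b/a - 1$ in closed form and $\mathbf{F}(\theta) = 1/\theta^2$: both directions agree with $\frac{1}{2}\delta\theta^\top \mathbf{F} \delta\theta$ up to third-order terms.

```python
import math

# The exponential distribution Exp(rate) has an asymmetric KL divergence in
# closed form, KL(Exp(a) || Exp(b)) = log(a/b) + b/a - 1, and Fisher
# information F(rate) = 1/rate^2.
def kl_exp(a, b):
    return math.log(a / b) + b / a - 1

rate, delta = 2.0, 0.05
quad = 0.5 * delta**2 / rate**2       # (1/2) * delta^T F delta

forward = kl_exp(rate, rate + delta)  # KL(p_theta || p_{theta+delta})
reverse = kl_exp(rate + delta, rate)  # KL(p_{theta+delta} || p_theta)
print(forward, reverse, quad)         # all ~3.1e-4; differences are O(delta^3)
```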
The Natural Gradient Update
The natural gradient descent update replaces the ordinary gradient with the natural gradient:
\[\theta_{t+1} = \theta_t - \eta \, \mathbf{F}(\theta_t)^{-1} \nabla_\theta \mathcal{L}(\theta_t)\]This corresponds to steepest descent with respect to the KL divergence rather than Euclidean distance in parameter space.
Natural Gradient Descent Example
Let’s revisit the previous example with the loss function $\mathcal{L}(\theta) = \frac{1}{2}(\theta_1^2 + 10\theta_2^2)$. Here, the Fisher Information Matrix corresponds exactly to the Hessian of the loss function, which is constant:
\[\mathbf{F}(\theta) = \begin{bmatrix} 1 & 0 \\ 0 & 10 \end{bmatrix}\]The inverse Fisher Information Matrix is:
\[\mathbf{F}^{-1} = \begin{bmatrix} 1 & 0 \\ 0 & 0.1 \end{bmatrix}\]When we apply the natural gradient update, we precondition the standard gradient with this inverse matrix.
```python
# Fisher Information Matrix (here equal to the Hessian of the loss)
F = torch.tensor([[1.0, 0.0], [0.0, 10.0]])
F_inv = torch.inverse(F)

# Natural gradient descent (reuses loss_fn, theta_1, theta_2, Z from above)
theta = torch.tensor([-2.5, 0.8], requires_grad=True)
lr = 0.2  # Step size
path_ngd = [theta.detach().numpy().copy()]
for _ in range(10):  # Converges along a direct path
    loss = loss_fn(theta)
    loss.backward()
    with torch.no_grad():
        # Natural gradient: F_inv * grad
        natural_grad = torch.mv(F_inv, theta.grad)
        theta -= lr * natural_grad
        theta.grad.zero_()
    path_ngd.append(theta.detach().numpy().copy())

path_ngd = np.array(path_ngd)
fig = go.Figure()
# Add contour plot
fig.add_trace(
    go.Contour(
        x=theta_1,
        y=theta_2,
        z=Z,
        colorscale="Viridis",
        showscale=False,
        ncontours=20,
        contours=dict(showlabels=False),
    )
)
# Add natural gradient descent path
fig.add_trace(
    go.Scatter(
        x=path_ngd[:, 0],
        y=path_ngd[:, 1],
        mode="lines+markers",
        name="Natural Gradient Descent",
        line=dict(color="green"),
        marker=dict(size=6, color="green"),
    )
)
fig.update_layout(
    title="Natural Gradient Descent",
    xaxis_title="Theta 1",
    yaxis_title="Theta 2",
    yaxis=dict(scaleanchor="x", scaleratio=1),
)
# Save as HTML and display
html_output = fig.to_html(include_plotlyjs="cdn", full_html=False)
display(HTML(html_output))
```
Notice how the natural gradient descent path moves directly towards the minimum, correcting for the curvature of the loss surface. By accounting for the geometry (the “elliptical” nature of the contours), the steps are rescaled in each direction appropriately.
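For a direct side-by-side view, the two traces computed above can be overlaid on the same contour plot. This is a small optional addition that reuses `theta_1`, `theta_2`, `Z`, `path_sgd`, and `path_ngd` from the previous snippets.

```python
# Overlay both optimization traces on one contour plot for comparison.
fig = go.Figure()
fig.add_trace(go.Contour(x=theta_1, y=theta_2, z=Z, colorscale="Viridis",
                         showscale=False, ncontours=20))
fig.add_trace(go.Scatter(x=path_sgd[:, 0], y=path_sgd[:, 1], mode="lines+markers",
                         name="Standard GD", line=dict(color="red")))
fig.add_trace(go.Scatter(x=path_ngd[:, 0], y=path_ngd[:, 1], mode="lines+markers",
                         name="Natural GD", line=dict(color="green")))
fig.update_layout(title="Standard vs. Natural Gradient Descent",
                  xaxis_title="Theta 1", yaxis_title="Theta 2",
                  yaxis=dict(scaleanchor="x", scaleratio=1))
display(HTML(fig.to_html(include_plotlyjs="cdn", full_html=False)))
```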
Standard vs. Natural Gradient Descent
To understand why the natural gradient descent path is so much more effective, as seen in the plot above, we need to dive deeper into the geometry of the update steps.
Tensor Equation Primer
In differential geometry, we distinguish between two types of vectors based on how their components transform when we change coordinates:
- Contravariant vectors (tangent vectors), denoted with upper indices $v^i$. These represent directions, rates of change, or displacements (e.g., velocity, parameter updates).
- Covariant vectors (covectors or dual vectors), denoted with lower indices $u_i$. These represent linear functionals that map tangent vectors to real numbers (e.g., gradients, forces).
Here are some common examples of each from physics and statistics:
| Concept | Contravariant Vector ($v^i$) | Covariant Vector ($u_i$) |
|---|---|---|
| Physics | Velocity ($dx^i/dt$); displacement ($dx^i$); electric field (if defined as force per charge) | Gradient ($\nabla \phi$); force ($F_i$, typically); momentum ($p_i$) |
| Statistics / ML | Parameter update ($\Delta \theta^i$); model parameters ($\theta^i$) | Gradient of loss ($\nabla_\theta \mathcal{L}$); score function ($\nabla_\theta \log p(x \mid \theta)$) |
| Transformation rule | Components transform opposite to the basis change | Components transform with the basis change |
In introductory physics, we often see $p = mv$, connecting velocity (contravariant) and momentum (covariant). This equality holds in Euclidean space because the metric is trivial (often the identity). In general settings (like Lagrangian mechanics or General Relativity), they are related by the metric tensor (or mass matrix) which “lowers” the index: $p_i = \sum_j m_{ij} v^j$.
Validity of Tensor Equations
For an equation to be physically meaningful or geometrically valid, it must be coordinate invariant. This means the relationship should hold regardless of the coordinate system chosen. In tensor notation, this requires that the indices on both sides of an equation match in position (all upper or all lower).
If we try to add or subtract a covariant vector from a contravariant vector:
\[v^i - u_i\]This operation is undefined. Why? Consider a simple coordinate scaling where the new coordinates are $\tilde{x} = 2x$.
- The contravariant vector’s components scale with the coordinates: $\tilde{v} = 2v$ (opposite to the basis vectors, which shrink by half).
- The covariant vector’s components scale inversely: $\tilde{u} = u / 2$.
The difference in the new coordinates would be $2v - u/2$, which is not simply a scaled version of the original $v - u$. The equation “breaks” because the two terms scale differently, making the result dependent on the arbitrary choice of coordinates.
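Here is a tiny numeric version of this scaling argument, using $f(x) = x^2$ as a stand-in function (my own illustration): the displacement’s components double while the gradient’s components halve, so their difference has no coordinate-independent meaning.

```python
# Numeric version of the scaling example: f(x) = x^2, a displacement v
# (contravariant), and the gradient u = f'(x) (covariant), recomputed in the
# rescaled coordinates x_tilde = 2x, where f becomes (x_tilde / 2)^2.
x, v = 3.0, 0.5
u = 2 * x                # f'(x) = 2x

x_t = 2 * x              # same point, new coordinates
v_t = 2 * v              # the displacement's components double
u_t = x_t / 2            # d/dx_tilde (x_tilde/2)^2 = x_tilde/2 = u/2: the gradient halves

print(v_t / v, u_t / u)  # 2.0 0.5 -- opposite transformation behavior
print(v - u, v_t - u_t)  # -5.5 -2.0 -- the mixed difference is not coordinate-invariant
```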
Formally, a valid tensor equation must satisfy the general transformation law. If $T$ is a tensor of a certain type, its components in the new coordinate system, $T'$, must relate to the components in the old system, $T$, via the Jacobian matrix $\frac{\partial x'}{\partial x}$ (or its inverse):
\[T'^{i \dots }_{j \dots} = \frac{\partial x'^i}{\partial x^k} \dots \frac{\partial x^l}{\partial x'^j} \dots T^{k \dots}_{l \dots}\]When indices do not match (e.g., adding a vector $v^i$ and a covector $u_i$), the resulting object does not transform according to any such consistent law, meaning it is not a tensor and has no coordinate-independent geometric meaning.
Why Standard Gradient Descent Fails on Manifolds
Standard gradient descent treats the gradient as a vector in the parameter space, but formally, it resides in the dual space.
The parameter update $\Delta \theta$ is a contravariant vector:
\[(\Delta \theta)^i\]The gradient of the loss function is a covariant vector:
\[(\nabla \mathcal{L})_i = \frac{\partial \mathcal{L}}{\partial \theta^i}\]The standard gradient descent update rule attempts to subtract a covariant vector from a contravariant vector:
\[\theta^{i}_{new} = \theta^{i}_{old} - \eta \frac{\partial \mathcal{L}}{\partial \theta^i}\]The indices do not match. We have an upper index on the left and a mix of upper and lower indices on the right. This operation is not well-defined on a general manifold because it is not invariant to coordinate transformations. It only “works” in Euclidean space because the metric tensor is the identity matrix, allowing us to loosely equate upper and lower indices.
To perform a valid update, we must “raise the index” of the gradient to transform it into a contravariant vector. This is done using the inverse metric tensor $G^{ij}$ (which in this case is the inverse Fisher Information Matrix $\mathbf{F}^{-1}$):
\[(\tilde{\nabla} \mathcal{L})^i = \sum_j G^{ij} (\nabla \mathcal{L})_j\]Now both the update direction and the parameter vector have upper indices, making the update geometrically consistent:
\[\theta^{i}_{new} = \theta^{i}_{old} - \eta \sum_j G^{ij} \frac{\partial \mathcal{L}}{\partial \theta^j}\]This is exactly the definition of the Natural Gradient update.
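This consistency is not just notational: the natural gradient direction transforms like a displacement under a change of coordinates, while the raw gradient does not. The sketch below (my own construction, using a quadratic loss whose metric is its Hessian and an arbitrary invertible reparameterization $\tilde{\theta} = A\theta$) checks both transformation rules numerically.

```python
import torch

# Quadratic loss L(theta) = 0.5 * theta^T H theta, with metric F = H.
# Reparameterize with an invertible A and compare how each object transforms.
H = torch.tensor([[1.0, 0.0], [0.0, 10.0]])
A = torch.tensor([[2.0, 1.0], [0.0, 0.5]])  # arbitrary invertible coordinate change
A_inv = torch.inverse(A)

theta = torch.tensor([-2.5, 0.8])
grad = H @ theta                    # gradient in the original coordinates
nat_grad = torch.inverse(H) @ grad  # natural gradient in the original coordinates

# In the new coordinates: L_tilde(t) = 0.5 * t^T (A^-T H A^-1) t
H_tilde = A_inv.T @ H @ A_inv
theta_tilde = A @ theta
grad_tilde = H_tilde @ theta_tilde
nat_grad_tilde = torch.inverse(H_tilde) @ grad_tilde

# The natural gradient transforms contravariantly, like a displacement:
print(torch.allclose(nat_grad_tilde, A @ nat_grad, atol=1e-5))  # True
# The raw gradient transforms covariantly (with A^-T), not like a displacement:
print(torch.allclose(grad_tilde, A_inv.T @ grad, atol=1e-5))    # True
print(torch.allclose(grad_tilde, A @ grad, atol=1e-5))          # False
```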
Why Natural Gradient Works
By premultiplying the gradient by the inverse Fisher Information Matrix $\mathbf{F}^{-1}$, natural gradient descent effectively transforms the covariant gradient into a contravariant, Riemannian gradient:
\[\tilde{\nabla} \mathcal{L} = \mathbf{F}^{-1} \nabla \mathcal{L}\]This vector points in the direction of steepest descent with respect to the KL divergence (the intrinsic distance on the manifold), rather than Euclidean distance. This accounts for the local geometry of the probability distributions, correcting for cases where a small change in parameters leads to a large change in the distribution (high curvature/Fisher information).
This approach is closely related to Newton’s Method, a second-order optimization technique that updates parameters using the inverse Hessian $\mathbf{H}^{-1}$:
\[\theta_{t+1} = \theta_t - \eta \mathbf{H}^{-1} \nabla \mathcal{L}\]In many log-likelihood problems, the Fisher Information Matrix approaches the Hessian of the loss function (or they are identical in expectation, via Bartlett’s identity). Therefore, Natural Gradient Descent can be viewed as a second-order method that uses the Fisher Information as a curvature proxy for the Hessian, often providing faster convergence than first-order methods without the instability or computational cost of the exact Hessian.
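For the quadratic loss in our running example, this equivalence is exact: the autograd Hessian equals the constant Fisher matrix, so the Newton step and the natural gradient step coincide. A quick check (a sketch reusing the loss from the example):

```python
import torch
from torch.autograd.functional import hessian

# For the quadratic loss, the autograd Hessian equals the constant Fisher
# matrix, so Newton and natural gradient steps are identical.
def loss_fn(theta):
    return 0.5 * (theta[0] ** 2 + 10 * theta[1] ** 2)

theta = torch.tensor([-2.5, 0.8], requires_grad=True)
loss_fn(theta).backward()

H = hessian(loss_fn, theta.detach())         # [[1, 0], [0, 10]]
F = torch.tensor([[1.0, 0.0], [0.0, 10.0]])  # Fisher matrix from the example
newton_step = torch.linalg.solve(H, theta.grad)
natural_step = torch.linalg.solve(F, theta.grad)
print(torch.allclose(newton_step, natural_step))  # True
```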
Connection to Adam and Adaptive Methods
Adaptive optimization algorithms like Adam and RMSProp can be interpreted as approximating the Natural Gradient update using a diagonal approximation of the Fisher Information Matrix.
Recall that the Fisher Information Matrix is the expected outer product of the gradients:
\[\mathbf{F} = \mathbb{E}[\nabla \log p(x\mid\theta) \nabla \log p(x\mid\theta)^\top]\]If we approximate this expectation with an exponential moving average of squared gradients (as done in RMSProp and Adam), and further constrain $\mathbf{F}$ to be a diagonal matrix, we get:
\[\mathbf{F}_{\text{diag}} \approx \text{diag}(\mathbb{E}[(\nabla_\theta \mathcal{L})^2])\]The natural gradient update with this diagonal approximation would be:
\[\Delta \theta \propto -\mathbf{F}_{\text{diag}}^{-1} \nabla_\theta \mathcal{L} = -\frac{\nabla_\theta \mathcal{L}}{\mathbb{E}[(\nabla_\theta \mathcal{L})^2]}\]where the division is element-wise. However, Adam uses the square root of this diagonal approximation in the denominator:
\[\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{v_t} + \epsilon} m_t\]where $v_t$ is the exponential moving average of squared gradients (approximating $\text{diag}(\mathbf{F})$) and $m_t$ is the momentum.
Thus, Adam can be viewed as a “softened” diagonal Natural Gradient method. By using the square root, Adam scales updates based on the magnitude (standard deviation) of the gradients rather than the variance, which provides better numerical stability and scaling properties in practice.
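A tiny sketch of the difference, with made-up gradient and accumulator values: dividing by $v$ itself implements the diagonal natural gradient update, while dividing by $\sqrt{v}$, as Adam does, merely normalizes gradient magnitudes.

```python
import torch

# v plays the role of the Adam/RMSProp exponential moving average of squared
# gradients; here both tensors are hypothetical values for illustration.
grad = torch.tensor([0.3, 3.0])  # gradients at the current step
v = torch.tensor([0.09, 9.0])    # EMA of squared gradients
eps = 1e-8

diag_ngd_step = grad / (v + eps)     # precondition by the variance (diagonal F)
adam_step = grad / (v.sqrt() + eps)  # precondition by the standard deviation

print(diag_ngd_step)  # tensor([3.3333, 0.3333]) -- aggressive rescaling
print(adam_step)      # tensor([1., 1.])         -- magnitudes equalized
```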
References
- Amari, S. (1998). Natural Gradient Works Efficiently in Learning. Neural Computation.