Diffusion Language Models
My notes on the paper Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution.
import torch
from matplotlib import pyplot as plt
Motivation
Goal is to learn:
\[\mathbb E_{x\sim p_{data}}[-\log p_\theta(x)] \approx \frac{1}{n} \sum_{i=1}^n -\log{p_\theta(x_i)}\]which we can train a neural network to do by modeling
\[p_\theta(x) = \frac{e^{f_{\theta}(x)}}{Z}, Z = \sum_{x \in \mathcal X} e^{f_\theta(x)}\]But $Z$ is intractable: the sum has exponentially many terms (e.g. $50257^{1024}$ for a 1024-token sequence with GPT-2's vocabulary)
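Just to get a sense of that count, a quick back-of-the-envelope check (not from the paper) of the number of decimal digits in $|\mathcal X|$ for the GPT-2 example:
# Z sums over every possible sequence: 50257 ** 1024 terms for GPT-2 at length 1024
print(len(str(50257 ** 1024)))  # ~4.8k decimal digits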
Autoregressive
Attempts the factorization
\[p_{data}(x^1 x^2 \dots x^d) = p_{data}(x^1) p_{data}(x^2 \mid x^1) \dots p_{data}(x^d \mid x^1 \dots x^{d - 1})\]Only needing to learn $p_\theta(x^i \mid x^1 \dots x^{i-1}) \approx p_{data}(x^i \mid x^1 \dots x^{i-1})$
which lets us generate new sequences via ancestral sampling.
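As a toy illustration (not from the paper), ancestral sampling from a made-up bigram model; the transition table below is invented just for this sketch, and a real language model would condition on the whole prefix $x^1 \dots x^{i-1}$:
# toy ancestral sampling: draw each token from p(x^i | x^{i-1}) under a bigram table
bigram = torch.tensor([[0.7, 0.2, 0.1],
                       [0.1, 0.6, 0.3],
                       [0.3, 0.3, 0.4]])  # row i: p(next token | current token = i)
seq = [0]                                 # arbitrary start token
for _ in range(8):
    seq.append(torch.multinomial(bigram[seq[-1]], 1).item())
print(seq)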
What does diffusion do?
Instead of modeling the probability directly, model the ratio:
\[\frac{p_\theta (y)}{p_\theta(x)} = \frac{e^{f_{\theta}(y)} / Z}{e^{f_{\theta}(x)} / Z} = \frac{e^{f_{\theta}(y)}}{e^{f_{\theta}(x)}}\]and somehow use these ratios to generate new sequences.
Diffusion Process
Forward diffusion process (linear ODE)
\[\begin{align*} \frac{dp_t}{dt} = Q_t p_t \quad p_0 \approx p_{data} \end{align*}\]$Q_t$ is a transition matrix whose columns sum to 0 (i.e. $\sum_i Q_{ij} = 0$ for every $j$), for example:
\[\begin{align*} Q_{uniform} = \begin{bmatrix} 1 - N & 1 & \dots & 1 \\ 1 & 1 - N & \dots & 1 \\ \vdots & \vdots & \ddots & \vdots \\ 1 & 1 & \dots & 1 - N \end{bmatrix} \end{align*}\]
N = 3
# uniform transition matrix: off-diagonal entries 1, diagonal 1 - N,
# so every column (and row) sums to zero
Q = torch.ones(N, N) - torch.eye(N) * N
Q
tensor([[-2., 1., 1.],
[ 1., -2., 1.],
[ 1., 1., -2.]])
Q.sum(0)
tensor([0., 0., 0.])
p = torch.tensor([0.5, 0.3, 0.2])  # an example distribution p_0 over the N states
dpdt = Q @ p
dpdt
tensor([-0.5000, 0.1000, 0.4000])
dpdt.sum()
tensor(-2.9802e-08)
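Since this toy $Q$ is constant in time, the forward ODE has the closed-form solution $p_t = \exp(t Q)\, p_0$; a quick sketch (the time values are arbitrary) showing $p_t$ relax toward the uniform distribution:
# closed-form solution of dp/dt = Q p for constant Q: p_t = exp(t Q) p_0
for t in [0.0, 0.1, 0.5, 2.0]:
    print(t, torch.linalg.matrix_exp(t * Q) @ p)  # approaches uniform as t grows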
Reverse Diffusion Process
\[\begin{align*} \frac{dp_{T - t}}{dt} &= \bar Q_{T - t} p_{T - t} \\ \bar Q_t(y, x) &= \frac{p_t(y)}{p_t(x)} Q_t(x, y) \quad (y \neq x) \\ \bar Q_t(x, x) &= -\sum_{y \neq x} \bar Q_t(y, x) \end{align*}\]The second equation resembles a detailed-balance relation; the last one makes the columns of $\bar Q_t$ sum to zero so that probability mass is conserved.
scores = p.view(1, -1) / p.view(-1, 1)  # ratio matrix: scores[x, y] = p(y) / p(x)
scores
tensor([[1.0000, 0.6000, 0.4000],
[1.6667, 1.0000, 0.6667],
[2.5000, 1.5000, 1.0000]])
# reverse transition matrix: Q_bar(y, x) = (p(y) / p(x)) * Q(x, y) off the diagonal
Q_bar = (scores * Q).T
# zero the diagonal, then set it to minus the off-diagonal column sums
Q_bar = Q_bar - torch.diag(torch.diag(Q_bar))
Q_bar -= torch.diag(torch.sum(Q_bar, dim=0))
Q_bar
tensor([[-1.0000, 1.6667, 2.5000],
[ 0.6000, -2.3333, 1.5000],
[ 0.4000, 0.6667, -4.0000]])
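As a quick numerical sanity check (not from the paper): take a small forward Euler step, rebuild $\bar Q$ from the perturbed distribution, and step back; to first order in the step size this recovers $p$:
# forward Euler step, then a reverse step built from the perturbed marginal
eps = 1e-3
p_eps = p + eps * (Q @ p)                        # forward step of dp/dt = Q p
ratios = p_eps.view(1, -1) / p_eps.view(-1, 1)   # p_eps(y) / p_eps(x)
Qb = (ratios * Q).T
Qb -= torch.diag(torch.diag(Qb))
Qb -= torch.diag(Qb.sum(dim=0))                  # columns sum to zero
print(p_eps + eps * (Qb @ p_eps))                # ~ tensor([0.5, 0.3, 0.2])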
Concrete Score
\[s(x)_y = \frac{p_t(y)}{p_t(x)}\]This is related to the score function: for discrete probabilities, $\nabla_x \log{p} = \frac{\nabla_x{p}}{p} \approx \frac{p(y) - p(x)}{p(x)} = \frac{p(y)}{p(x)} - 1 = s(x)_y - 1$
Score Entropy
\[\begin{align*} \mathcal L_{SE} = \mathbb E_{x \sim p}[\sum_{y \neq x} w_{xy} \big(s_\theta(x)_y - \frac{p(y)}{p(x)} \log{s_\theta(x)_y} + K (\frac{p(y)}{p(x)}) \big)] \end{align*}\]- $K(a) = a(\log{a} - 1)$ ensures $\mathcal L_{SE} \geq 0$
- Built on the Bregman divergence $D_F(s(x)_y, \frac{p(y)}{p(x)})$ with $F = -\log$
- guarantees non-negativity, symmetry, and convexity
- generalizes standard cross entropy from simplex values to arbitrary positive values
# perturb the true ratios with increasingly large noise and check that the
# score entropy loss grows with the size of the perturbation
noise = torch.rand_like(scores).unsqueeze(0)
noise_schedule = torch.arange(0, 10).view(-1, 1, 1) * 20
noise = noise * noise_schedule
scores_hat = scores + noise            # "model" scores: true ratios + noise
mask = torch.ones_like(scores) - torch.eye(scores.shape[-1])  # exclude y == x
L_se = scores_hat - scores * torch.log(scores_hat)
K = scores * (torch.log(scores) - 1)   # K(a) = a (log a - 1)
L_se += K
L_se *= mask
L_se = L_se.sum(dim=-1).mean(dim=-1)
torch.stack((noise_schedule.flatten(), L_se)).T.round(decimals=1)
tensor([[ 0.0000, 0.0000],
[ 20.0000, 12.9000],
[ 40.0000, 28.9000],
[ 60.0000, 45.6000],
[ 80.0000, 62.5000],
[100.0000, 79.5000],
[120.0000, 96.6000],
[140.0000, 113.8000],
[160.0000, 131.0000],
[180.0000, 148.2000]])
# per-term score entropy as a function of the model score s for a fixed true ratio
# (the constant K(true_ratio) term is dropped); the -log term acts as a barrier at 0
s = torch.logspace(-9, 1 / 1.5, steps=100)
true_ratio = 0.4
se_loss = s - true_ratio * torch.log(s)
plt.plot(s.numpy(), se_loss.numpy())
plt.axvline(x=true_ratio, color='r', linestyle='--', label='True Ratio')
plt.axhline(y=se_loss.min().item(), color='g', linestyle='--', label='Min Loss')
plt.legend()
plt.xlabel("s")
plt.ylabel("Score Entropy Loss")
plt.title("Score Entropy Loss")
plt.show()
Score entropy properties
- Consistency, $\mathcal L_{SE} = 0$ at $\theta^*$
- has a log barrier that keeps $s_\theta \geq 0$
- can be made tractable by removing the $\frac{p(y)}{p(x)}$ term
Implicit Score Entropy
Remove the $\frac{p(y)}{p(x)}$ term while staying equal to $\mathcal L_{SE}$ up to a constant (independent of $\theta$):
\[\mathcal L_{ISE} = \mathbb E_{x \sim p} [\sum_{y \neq x} w_{xy} s_\theta(x)_y - w_{yx} \log{s_\theta(y)_x}]\]Still not tractable in practice, since $s_\theta(y)_x$ and $s_\theta(x)_y$ appear in the same sum
# same experiment with the implicit form (weights w = 1); note the transpose,
# which swaps the roles of x and y inside the log term
L_ise = scores_hat - torch.log(scores_hat.transpose(-1, -2))
L_ise = L_ise * mask
L_ise = L_ise.sum(-1).mean(-1)
torch.stack([noise_schedule.flatten(), L_ise]).T.round(decimals=1)
tensor([[ 0.0000, 2.4000],
[ 20.0000, 15.5000],
[ 40.0000, 31.8000],
[ 60.0000, 48.5000],
[ 80.0000, 65.5000],
[100.0000, 82.6000],
[120.0000, 99.8000],
[140.0000, 117.1000],
[160.0000, 134.3000],
[180.0000, 151.7000]])
L_se - L_ise
tensor([-2.4444, -2.5890, -2.8114, -2.9603, -3.0719, -3.1611, -3.2354, -3.2991,
-3.3548, -3.4043])
Denoising Score Entropy
Remove the $s_\theta(y)_x$ term.
Suppose $p = \sum_{x_0} p(x \mid x_0)p_0(x_0)$ is a perturbation of $p_0$ (some base density):
\[\mathcal L_{DSE} = \mathbb E_{\stackrel{x_0 \sim p_0}{x \sim p(\cdot \mid x_0)}} [\sum_{y \neq x} w_{xy} \big(s_\theta(x)_y - \frac{p(y \mid x_0)}{p(x \mid x_0)} \log{s_\theta(x)_y}\big)]\]- only involves $s_\theta(x)_y$
- $L_{DSE} \sim L_{SE}$ up to a constant
- $p_t$ are intermediate perturbations of $p_0$ (a toy evaluation of $\mathcal L_{DSE}$ follows below)
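A minimal sketch (not from the paper) evaluating $\mathcal L_{DSE}$ on the 3-state toy example with $w_{xy} = 1$; the perturbation kernel is taken to be $p_{t\mid 0} = \exp(tQ)$, and t_toy, x0, and the deliberately mis-scaled model score are invented for this illustration:
# perturb a point mass at state x0 with p_{t|0} = exp(t Q), then compare the DSE
# loss of the true conditional ratios against a mis-scaled score model
t_toy = 0.3
P_t = torch.linalg.matrix_exp(t_toy * Q)            # column x0 is p_{t|0}(. | x0)
x0 = 0
p_t_given_x0 = P_t[:, x0]
ratio = p_t_given_x0.view(1, -1) / p_t_given_x0.view(-1, 1)  # p(y|x0) / p(x|x0)
for s_hat in (ratio, ratio * 1.5):                  # true ratios vs. a wrong model
    per_term = (s_hat - ratio * torch.log(s_hat)) * mask     # only y != x terms
    print((per_term.sum(-1) * p_t_given_x0).sum())  # expectation over x ~ p_{t|0}(.|x0)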
Diffusion Weighted Denoising Score Entropy
Create an ELBO for $p^\theta_0 (x_0)$.
We can construct an ELBO for the data distribution $p_0$ using a weighted version of $\mathcal L_{DSE}$ with weights $w_{xy} = Q_t(x_t, y)$,
\[\begin{align*} \mathcal L_{DWDSE} &= \int_{0}^T \mathbb E_{x_t \sim p_{t\mid 0}(\cdot \mid x_0)} \bigg[ \sum_{y \neq x_t} Q_t(x_t, y) \bigg( s_{\theta} (x_t, t)_y - \frac{p_{t\mid 0}(y \mid x_0)}{p_{t\mid 0} (x_t \mid x_0)} \log{s_{\theta} (x_t, t)_y} \\ &\qquad\qquad + K \big(\frac{p_{t\mid 0}(y \mid x_0)}{p_{t\mid 0} (x_t \mid x_0)} \big) \bigg) \bigg] \, dt \\ &= \int_{0}^T \mathbb E_{x_t \sim p_{t\mid 0}(\cdot \mid x_0)} \bigg[\mathcal L_{DSE} + K \big(\frac{p_{t\mid 0}(y \mid x_0)}{p_{t\mid 0} (x_t \mid x_0)} \big) \bigg] \, dt \end{align*}\]which forms the ELBO:
\[-\log{p^\theta_0 (x_0)} \leq \mathcal L_{DWDSE}(x_0) + D_{KL}(p_{T\mid 0}(\cdot \mid x_0) \mid\mid p_{base})\]
Practical Implementation
Since tokens come from $\mathcal X = \{1, \dots, n\}$ and form sequences $x = x^1 \dots x^d$, a transition matrix $Q_t$ over whole sequences would be exponentially large. Instead, change one token at a time (Hamming distance 1),
\[Q^{tok}_t (x^i, \hat x^i) = Q_t(x^1 \dots x^i \dots x^d, x^1 \dots \hat x^i \dots x^d) \\ (s_\theta(x^1 \dots x^i \dots x^d))_{i, \hat x^i} \approx \frac{p_t(x^1 \dots \hat x^i \dots x^d)}{p_t(x^1 \dots x^i \dots x^d)}\]leading towards the factorization
\[p^{seq}_{t\mid 0}(\hat x \mid x) = \prod_{i=1}^d p^{tok}_{t\mid 0}(\hat x^i \mid x^i)\]Using $Q_t^{tok} = \sigma(t) Q^{tok}$ and cumulative noise $\bar \sigma(t) = \int_{0}^{t} \sigma(s)ds$, we have:
\[p^{tok}_{t\mid 0}(\cdot \mid x) = \text{x-th column of } \exp(\bar \sigma(t) Q^{tok})\]$\bar \sigma(t)$ can follow the log-linear noise schedule $\bar \sigma(t) = -\log(1 - (1 - \epsilon) t)$ or the geometric schedule $\bar \sigma(t) = \sigma_{min}^{1 - t}\sigma_{max}^{t}$
# geometric-style schedule: sigma ** t (the sigma_min ** (1 - t) factor is dropped)
noise = lambda t, sigma: torch.pow(sigma, t)
noise(torch.arange(0, 1, 0.1), 1e-2)
tensor([1.0000, 0.6310, 0.3981, 0.2512, 0.1585, 0.1000, 0.0631, 0.0398, 0.0251,
0.0158])
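A small sketch of the per-token perturbation kernel defined above, reusing the uniform $Q$ from earlier (the noise levels are arbitrary); each column of $\exp(\bar\sigma Q)$ gives $p^{tok}_{t\mid 0}(\cdot \mid x)$ and approaches the uniform distribution as the noise grows:
# p_{t|0}(. | x) is the x-th column of exp(bar_sigma * Q)
for bar_sigma in [0.01, 0.1, 1.0, 10.0]:
    P = torch.linalg.matrix_exp(bar_sigma * Q)
    print(bar_sigma, P[:, 0])  # kernel starting from token 0; tends to uniform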
Training Algorithm
Sample $x_0 \sim p_0$, $t \sim \mathcal U([0, T])$
Construct $x_t$ from $x_0$: $x_t^i \sim p_{t \mid 0}(\cdot \mid x_0^i) = \exp(\bar \sigma (t) Q)_{\cdot,\, x_0^i}$ (the $x_0^i$-th column of the matrix exponential)
Compute and backpropagate:
\[\hat{\mathcal{L}}_{DWDSE} = \sigma(t) \sum_{i=1}^d \sum_{y=1}^n (1 - \delta_{x_t^i}(y)) \big( s_\theta (x_t, t)_{i, y} - \frac{p_{t \mid 0}(y \mid x_0^i)}{p_{t \mid 0}(x_t^i \mid x_0^i)} \log{s_\theta (x_t, t)_{i, y}} \big)\]
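A minimal single-step training sketch of this procedure on the toy 3-token vocabulary, reusing the uniform $Q$ from earlier; the tiny score_model, the sequence length, and the schedule constants are invented for this illustration (a real $s_\theta$ is a sequence model that sees all of $x_t$):
import torch.nn as nn

n_vocab, d_seq, eps = 3, 5, 1e-3
# stand-in score network: per-token one-hot + t -> positive scores over the vocab
score_model = nn.Sequential(nn.Linear(n_vocab + 1, 32), nn.SiLU(),
                            nn.Linear(32, n_vocab), nn.Softplus())
bar_sigma = lambda t: -torch.log1p(-(1 - eps) * t)    # log-linear cumulative noise
sigma = lambda t: (1 - eps) / (1 - (1 - eps) * t)     # its derivative sigma(t)

x0 = torch.randint(0, n_vocab, (d_seq,))              # a "data" sequence
t = torch.rand(())                                    # t ~ U([0, 1])
P = torch.linalg.matrix_exp(bar_sigma(t) * Q)         # p_{t|0}: column j is the kernel from token j
xt = torch.multinomial(P[:, x0].T, 1).squeeze(-1)     # x_t^i ~ p_{t|0}(. | x0^i)

inp = torch.cat([nn.functional.one_hot(xt, n_vocab).float(), t.expand(d_seq, 1)], -1)
s = score_model(inp)                                  # s_theta(x_t, t), shape (d, n)
ratio = P[:, x0].T / P[xt, x0].unsqueeze(-1)          # p(y | x0^i) / p(x_t^i | x0^i)
off_diag = 1 - nn.functional.one_hot(xt, n_vocab).float()  # 1 - delta_{x_t^i}(y)
loss = sigma(t) * (off_diag * (s - ratio * torch.log(s))).sum()
loss.backward()
print(loss.item())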
Sampling
$\tau$-leaping
\[x_{t - \Delta t}^i \sim \delta_{x_t^i}(x_{t - \Delta t}^i) + \Delta t\, Q_t^{tok} (x_t^i, x_{t - \Delta t}^i)\, s_\theta(x_t, t)_{i, x_{t - \Delta t}^i}\]
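A minimal sketch of one $\tau$-leaping update for a single token position, again on the toy vocabulary; the step size, the $\sigma(t)$ value, and the hand-picked s_hat stand in for a trained score model:
# one tau-leaping step for position i: delta at x_t plus dt * Q_t^tok * score
dt, sigma_t = 0.05, 1.0
x_t = 1                                        # current token id
s_hat = torch.tensor([2.0, 1.0, 0.5])          # stand-in for s_theta(x_t, t)_i
probs = torch.zeros(N)
probs[x_t] = 1.0                               # delta_{x_t^i}
probs += dt * sigma_t * Q[x_t] * s_hat         # dt * Q_t^tok(x_t^i, y) * score
probs = probs.clamp(min=0)
probs = probs / probs.sum()
x_prev = torch.multinomial(probs, 1).item()    # sample x_{t - dt}^i
print(probs, x_prev)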
Tweedie $\tau$-leaping
\[p_{t - \Delta t \mid t}(x_{t - \Delta t} \mid x_t)^{\text{tweedie}} = (\exp{(-\sigma_t^{\Delta t} Q)}\, s_\theta (x_t, t)_i)_{x_{t - \Delta t}^i} \exp{(\sigma_t^{\Delta t} Q)}(x_t^i, x_{t - \Delta t}^i)\]where $\sigma_t^{\Delta t} = \bar \sigma(t) - \bar \sigma(t - \Delta t)$ is the noise accumulated between $t - \Delta t$ and $t$.
Prompting & Infilling
$p_t(x^\Omega \mid x^{\bar \Omega} = y)$ can be recovered exactly:
\[\begin{align*} \frac{p_t(x^\Omega =z' \mid x^{\bar \Omega} = y)}{p_t(x^\Omega = z \mid x^{\bar \Omega} = y)} &= \frac{p_t(x^\Omega =z' \cap x^{\bar \Omega} = y) / p_t(x^{\bar \Omega} = y)}{p_t(x^\Omega =z \cap x^{\bar \Omega} = y) / p_t(x^{\bar \Omega} = y)} \\ &= \frac{p_t(x^\Omega =z' \cap x^{\bar \Omega} = y)}{p_t(x^\Omega =z \cap x^{\bar \Omega} = y)} \end{align*}\]
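A toy numerical check of this identity (the 3×3 joint distribution below is random, purely for illustration):
# for a joint p over two tokens, ratios of the conditional p(x^1 | x^2 = y)
# equal ratios of the joint probabilities, since the normalizer cancels
joint = torch.rand(3, 3)
joint = joint / joint.sum()               # joint p(x^1, x^2)
y = 1                                     # condition on x^2 = y
cond = joint[:, y] / joint[:, y].sum()    # p(x^1 | x^2 = y)
print(cond[2] / cond[0], joint[2, y] / joint[0, y])  # identical ratios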