
Introduction to AMP and the Replica Trick

January 26, 2019
by

(This post is based on the lecture by Yueqi Sheng.)

In this post, we will talk about detecting phase transitions using
Approximate Message Passing (AMP), which is an extension of
Belief Propagation (BP) to “dense” models. We will also discuss the replica
symmetric trick, a heuristic method for analyzing phase
transitions. We focus on the Rademacher spiked Wigner model (defined
below), and show that both methods yield the same phase transition
in this setting.

The Rademacher spiked Wigner model (RSW) is the following. We are given
observations Y = \frac{\lambda}{n}xx^T + \frac{1}{\sqrt{n}}W, where
x \in \{\pm 1\}^n (sampled uniformly) is the true signal and W is a
Gaussian Orthogonal Ensemble (GOE) matrix: W is symmetric, with
W_{i, j} \sim \mathcal{N}(0, 1) for i \neq j and
W_{i, i} \sim \mathcal{N}(0, 2). Here \lambda is the signal-to-noise
ratio. The goal is to approximately recover x.
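To make the model concrete, here is a minimal NumPy sketch of drawing one instance. It assumes the GOE matrix is built by symmetrizing an i.i.d. Gaussian matrix, which gives exactly the variances stated above; the helper name `sample_rsw` is our own and not from the post.

```python
import numpy as np


def sample_rsw(n, lam, rng):
    """Draw (x, Y) from the Rademacher spiked Wigner model.

    x is uniform on {-1, +1}^n and Y = (lam / n) x x^T + W / sqrt(n),
    where W is GOE: symmetric, N(0, 1) off the diagonal, N(0, 2) on it.
    """
    x = rng.choice([-1.0, 1.0], size=n)
    G = rng.standard_normal((n, n))
    # (G + G^T) / sqrt(2): off-diagonal variance (1 + 1) / 2 = 1, diagonal variance 4 / 2 = 2.
    W = (G + G.T) / np.sqrt(2.0)
    Y = (lam / n) * np.outer(x, x) + W / np.sqrt(n)
    return x, Y


rng = np.random.default_rng(0)
x, Y = sample_rsw(500, lam=1.5, rng=rng)
# Undo the spike and the 1/sqrt(n) scaling to inspect the noise matrix W.
W_hat = np.sqrt(500) * (Y - (1.5 / 500) * np.outer(x, x))
print(W_hat[~np.eye(500, dtype=bool)].var(), np.diag(W_hat).var())  # roughly 1 and 2
```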

The questions here are: how small can \lambda be before it becomes
impossible to recover anything reasonably correlated with the
ground truth x? And what do the approximate message passing algorithm
and the replica method have to say about this?

To answer the first question, one can think of the task as
distinguishing Y \sim \frac{\lambda}{n}xx^T + \frac{1}{\sqrt{n}}W from
Y \sim \frac{1}{\sqrt{n}}W. One approach to distinguishing these distributions is to
look at the spectrum of the observation matrix Y. (In fact, this turns
out to be an asymptotically optimal distinguisher [1].) The spectrum of Y behaves as follows [2]:

  • When \lambda \leq 1, the empirical distribution of eigenvalues in
    the spiked model still follows the semicircle law on [-2, 2], and the top
    eigenvalue is \approx 2.

  • When \lambda > 1, an outlier eigenvalue > 2 appears in the
    planted model; it converges to \lambda + \frac{1}{\lambda}.
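The two regimes can be seen numerically. The sketch below compares the top eigenvalue of a sampled Y against 2 for \lambda \leq 1 and against the rank-one outlier location \lambda + 1/\lambda for \lambda > 1; the value n = 1000 and the helper names are illustrative choices, and at finite n the agreement is only approximate.

```python
import numpy as np


def sample_rsw(n, lam, rng):
    # GOE noise plus a rank-one Rademacher spike, as in the model definition.
    x = rng.choice([-1.0, 1.0], size=n)
    G = rng.standard_normal((n, n))
    W = (G + G.T) / np.sqrt(2.0)
    return x, (lam / n) * np.outer(x, x) + W / np.sqrt(n)


def top_eigenvalue(Y):
    # eigvalsh returns the eigenvalues of a symmetric matrix in ascending order.
    return np.linalg.eigvalsh(Y)[-1]


rng = np.random.default_rng(1)
n = 1000
tops = {}
for lam in [0.5, 1.5, 2.0]:
    _, Y = sample_rsw(n, lam, rng)
    tops[lam] = top_eigenvalue(Y)
    predicted = 2.0 if lam <= 1 else lam + 1.0 / lam
    print(f"lambda={lam}: top eigenvalue {tops[lam]:.3f}, predicted limit {predicted:.3f}")
```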

Approximate message passing

This section approximately follows the exposition in [3].

First, note that in the Rademacher spiked Wigner model, since the prior on
the signal \sigma is uniform, its posterior distribution conditioned on the observation Y
is \Pr[\sigma | Y] \propto \Pr[Y | \sigma] \propto \prod_{i \neq j} \exp(\lambda Y_{i, j} \sigma_i \sigma_j /2 ).
This defines a graphical model (or “factor graph”), over which we can run
Belief Propagation to infer the posterior distribution of \sigma.
However, in this case the factor graph is dense: the distribution is a
product of potentials \exp(\lambda Y_{i, j} \sigma_i\sigma_j), one for every
pair i < j.
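For very small n the posterior can be computed exactly by enumerating all 2^n sign vectors, which is a useful reference point and also shows why message passing is needed for realistic n. The sketch below assumes a tiny instance (n = 10). Because every potential is invariant under \sigma \to -\sigma, the exact posterior mean is zero, so the signal can only be recovered up to a global sign; the informative quantity is {{\mathbb{E}}}[\sigma\sigma^T | Y].

```python
import itertools

import numpy as np


def posterior(Y, lam):
    """Exact Pr[sigma | Y], proportional to prod_{i<j} exp(lam * Y_ij * sigma_i * sigma_j), by enumeration.

    Feasible only for tiny n (2^n configurations).
    """
    n = Y.shape[0]
    sigmas = np.array(list(itertools.product([-1.0, 1.0], repeat=n)))  # shape (2^n, n)
    off = Y - np.diag(np.diag(Y))  # diagonal terms sigma_i^2 = 1 are constant
    # sigma^T off sigma counts every pair i<j twice, hence the factor 1/2.
    log_w = 0.5 * lam * np.einsum("si,ij,sj->s", sigmas, off, sigmas)
    w = np.exp(log_w - log_w.max())  # subtract the max for numerical stability
    return sigmas, w / w.sum()


rng = np.random.default_rng(2)
n, lam = 10, 2.0
x = rng.choice([-1.0, 1.0], size=n)
G = rng.standard_normal((n, n))
Y = (lam / n) * np.outer(x, x) + (G + G.T) / np.sqrt(2.0 * n)
sigmas, p = posterior(Y, lam)
post_mean = p @ sigmas                         # exact E[sigma | Y], zero by sign symmetry
post_corr = (p[:, None] * sigmas).T @ sigmas   # exact E[sigma sigma^T | Y]
```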

In the previous blog post, we saw that belief propagation works well when the underlying interaction
graph G is sparse. Intuitively, this is because a sparse G is locally tree-like,
which allows us to treat the incoming messages as independent random
variables. In a dense model this no longer holds; instead, one can think of the dense
model as each node receiving a weak signal from all of its neighbors.

In the dense setting, a class of algorithms called approximate
message passing (AMP) has been proposed as an alternative to BP. We will
define AMP for the RSW model in terms of its state evolution.

State evolution of AMP for Rademacher spiked Wigner model

Recall that in BP, we wish to infer the posterior distribution of
\sigma, and the messages we pass between nodes correspond to marginal
probability distributions over the values of nodes. In our setting, since the
distributions are over \{\pm 1\}, we can represent each distribution by
its expected value. Let m^t_{u \to v} \in [-1, 1] denote the
message from u to v at time t. That is, m_{u \to v} corresponds
to the expected value {{\mathbb{E}}}[\sigma_u] computed from all of u's neighbors except v.

To derive the BP update rules, we want to compute the expectation
{{\mathbb{E}}}[\sigma_v] of a node v, given the
messages {{\mathbb{E}}}[\sigma_u] for u \neq v. We can
do this using the posterior distribution \Pr[\sigma | Y] of the RSW model,
which we computed above:
\displaystyle \Pr[\sigma_v = 1 | Y, \{\sigma_u\}_{u \neq v}] = \frac{ \prod_{u \neq v} \exp(\lambda Y_{u, v} \sigma_u) }{ \prod_{u \neq v} \exp(\lambda Y_{u, v} \sigma_u) + \prod_{u \neq v} \exp(-\lambda Y_{u, v} \sigma_u) }

and similarly for \Pr[\sigma_v = -1 | Y, \{\sigma_u\}_{u \neq v}]. In particular,
{{\mathbb{E}}}[\sigma_v | Y, \{\sigma_u\}_{u \neq v}] = \tanh(\lambda \sum_{u \neq v} Y_{u, v}\sigma_u).
From this, we can take expectations over the \sigma_u and express
{{\mathbb{E}}}[\sigma_v] in terms of
\{{{\mathbb{E}}}[\sigma_u]\}_{u \neq v}. Doing this (and
using the heuristic assumption that the distribution of \sigma is a
product distribution), we find that the BP state update can be written
as
m^{t}_{u \to v} = f(\sum_{w \neq v}f^{-1}(A_{w, u} m^{t - 1}_{w \to u})),
where the interaction matrix is A_{w, u} = \lambda Y_{w, u} and
f(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}.
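On a dense graph the BP messages form an n \times n array, and the update above can be vectorized by first summing all incoming terms at u and then subtracting the receiver's own term. The following is a minimal sketch of one update step under that layout (M[w, u] stores m_{w \to u}); the function name and the random initialization are our own illustrative choices, not something the post specifies.

```python
import numpy as np


def bp_step(A, M):
    """One BP update m_{u->v} = tanh( sum_{w != v} arctanh(A[w, u] * M[w, u]) ).

    M[w, u] holds the message m_{w->u}; the diagonal is unused and kept at 0.
    Assumes |A[w, u] * M[w, u]| < 1, which holds when A = lam * Y has O(1/sqrt(n)) entries.
    """
    T = np.arctanh(A * M)        # T[w, u] = f^{-1}(A_{w,u} m_{w->u})
    np.fill_diagonal(T, 0.0)     # no self-messages
    S = T.sum(axis=0)            # S[u] = sum over all w != u of T[w, u]
    # Exclude the receiver v: new[u, v] = tanh(S[u] - T[v, u]).
    new = np.tanh(S[:, None] - T.T)
    np.fill_diagonal(new, 0.0)
    return new


rng = np.random.default_rng(3)
n, lam = 200, 1.5
x = rng.choice([-1.0, 1.0], size=n)
G = rng.standard_normal((n, n))
Y = (lam / n) * np.outer(x, x) + (G + G.T) / np.sqrt(2.0 * n)
A = lam * Y
# Hypothetical initialization: small random messages.
M0 = 0.1 * rng.uniform(-1.0, 1.0, size=(n, n))
np.fill_diagonal(M0, 0.0)
M1 = bp_step(A, M0)
```

Note that storing all n^2 messages is exactly the cost that AMP avoids by keeping one value per node.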

Now, Taylor expanding f^{-1}(x) = x + O(x^3) around 0, we find
m^{t}_{u \to v} = f\left( \sum_{w \neq v} A_{w, u} m^{t - 1}_{w \to u} + O(1/\sqrt{n}) \right),
since the terms A_{w, u} are of order O(1/\sqrt{n}), so each of the n cubic error terms is O(n^{-3/2}).

At this point, we could try dropping the “non-backtracking” condition
w \neq v from the above sum (since the node v contributes only
O(1/\sqrt{n}) to the sum anyway), to get the state update
m^{t}_{u} = f\left( \sum_{w} A_{w, u} m^{t - 1}_{w} \right) (note that the messages no longer
depend on the receiver, so we write m_u in place of m_{u \to v}).
However, this simplification turns out not to work for estimating the
signal. The problem is that the “backtracking” terms we added
are amplified over two iterations.

In AMP, we perform the above procedure, except that we add a
correction term to account for the backtracking issue. Given u,
for all v, the AMP update is:
m^{t}_{u \to v} = m^{t}_u = f(\sum_{w}A_{w, u} m^{t - 1}_{w}) + [\text{some correction term}]

The correction term accounts for the error introduced by the backtracking
terms. Here it is convenient to work with the fields m (the arguments of f), so that the
uncorrected update reads m^t = Y f(m^{t - 1}). Suppose everything is correct up to step t - 2. We examine
the influence of the backtracking terms on a node v through loops of length 2.
At time t - 1, v exerts an additional influence Y_{v, u}f(m^{t - 2}_v) on
each of its neighbors u. At time t, v receives this back, to first order as
Y_{u, v}^2 f'(m^{t - 1}_u) f(m^{t - 2}_v). Each Y_{u, v}^2 has magnitude
\approx \frac{1}{n}, but we sum over all n of v’s neighbors,
so this error term is too large to ignore. To find the exact form of the
correction, we write the non-backtracking field as m^{t - 1}_{u \to v} = m^{t - 1}_u - Y_{u, v}f(m^{t - 2}_v) and Taylor expand:

m^{t}_v = \sum_{u}Y_{u, v}f(m^{t - 1}_{u \to v}) = \sum_{u}Y_{u, v}f\left(m^{t - 1}_u - Y_{u, v}f(m^{t - 2}_v)\right)\\ \approx \sum_u Y_{u, v} f(m^{t - 1}_u) - \sum_u Y_{u, v}^2 f'(m^{t - 1}_u)f(m^{t - 2}_v)\\ \approx \sum_u Y_{u, v} f(m^{t - 1}_u) - \left(\frac{1}{n}\sum_{u}f'(m^{t - 1}_u)\right)f(m^{t - 2}_v)
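A minimal sketch of the corrected iteration m^t = Y f(m^{t-1}) - b_t f(m^{t-2}) with Onsager coefficient b_t = \frac{1}{n}\sum_u f'(m^{t-1}_u). Two choices here are our assumptions rather than the post's: the denoiser is taken to be f(m) = \tanh(\lambda m) (the posterior-mean denoiser matching the state evolution quoted from [4]), and AMP is warm-started with a field weakly correlated with x, as the state evolution analysis assumes; in practice a spectral start is a common alternative.

```python
import numpy as np


def sample_rsw(n, lam, rng):
    x = rng.choice([-1.0, 1.0], size=n)
    G = rng.standard_normal((n, n))
    return x, (lam / n) * np.outer(x, x) + (G + G.T) / np.sqrt(2.0 * n)


def amp(Y, lam, m_init, iters=30):
    """AMP: m^t = Y f(m^{t-1}) - b_t f(m^{t-2}), with b_t = mean_u f'(m^{t-1}_u).

    Assumes the denoiser f(m) = tanh(lam * m); returns f(m^T), an estimate of E[sigma].
    """
    def f(m):
        return np.tanh(lam * m)

    def f_prime(m):
        return lam * (1.0 - np.tanh(lam * m) ** 2)

    m_prev2 = np.zeros_like(m_init)   # m^{t-2}; f(0) = 0, so the first correction vanishes
    m_prev = m_init                   # m^{t-1}
    for _ in range(iters):
        b = np.mean(f_prime(m_prev))  # Onsager coefficient (1/n) sum_u f'(m^{t-1}_u)
        m_new = Y @ f(m_prev) - b * f(m_prev2)
        m_prev2, m_prev = m_prev, m_new
    return f(m_prev)


rng = np.random.default_rng(4)
n = 1000
overlaps = {}
for lam in [0.5, 2.0]:
    x, Y = sample_rsw(n, lam, rng)
    # Assumed warm start: lam * m0 = gamma0 * x + sqrt(gamma0) * g with gamma0 = 0.05.
    gamma0 = 0.05
    m0 = (gamma0 * x + np.sqrt(gamma0) * rng.standard_normal(n)) / lam
    est = amp(Y, lam, m0)
    overlaps[lam] = abs(est @ x) / n
    print(f"lambda={lam}: overlap {overlaps[lam]:.3f}")
```

Below the threshold the weak initial correlation dies out, while above it AMP locks onto the signal.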

State evolution of AMP

In this section we attempt to obtain the phase transition of the Rademacher
spiked Wigner model by examining the limit m^{\infty}.

We assume that each message can be written as the sum of a signal term and a
noise term: m^t = \mu_t x + \sigma_t g, where
g \sim \mathcal{N}(0, I). To understand the dynamics of AMP (and find its phase
transition), we need to look at how the signal strength \mu_t and the noise level
\sigma_t evolve with t.

We make the following simplification: ignore the correction term and
assume that the noise g is fresh (independent) at every iteration.

m^{t} = Yf(m^{t - 1}) = (\frac{\lambda}{n}xx^T + \frac{1}{\sqrt{n}}W)f(m^{t - 1}) = \frac{\lambda}{n} < f(m^{t - 1}), x > x + \frac{1}{\sqrt{n}} Wf(m^{t - 1})

Here, we see that \mu_t = \frac{\lambda}{n}< f(m^{t - 1}), x>
and \sigma_t g = \frac{1}{\sqrt{n}}Wf(m^{t - 1}).

Note that \mu_{t} is essentially proportional to the overlap between the
ground truth and the current belief, since the function f keeps the
magnitude of the current beliefs bounded. Using that f is odd and G is symmetric:

\frac{\lambda}{n} <f(m^{t - 1}), x>= \frac{\lambda}{n} <f(\mu_{t - 1}x + \sigma_{t - 1}g), x> \approx\lambda {{\mathbb{E}}}_{X \sim unif(\pm 1), G\sim \mathbb{N}(0, 1)}[X f(\mu_{t - 1}X + \sigma_{t - 1}G)] = \lambda {{\mathbb{E}}}_G[f(\mu_{t - 1} + \sigma_{t - 1}G)]

For the noise term, each coordinate of \sigma_t g is approximately a Gaussian random
variable with mean 0 and variance

\frac{1}{n} \sum_v f(m^{t - 1})_v^2 \approx {{\mathbb{E}}}_{X, G}[f(\mu_{t - 1}X + \sigma_{t - 1}G)^2] = {{\mathbb{E}}}_{G}[f(\mu_{t - 1} + \sigma_{t - 1}G)^2]

It was shown in [4] that we can introduce a new
parameter \gamma_t such that
\gamma_t = \lambda^2 {{\mathbb{E}}}[f(\gamma_{t - 1} + \sqrt{\gamma_{t - 1}}G)].
As t \to \infty, it turns out that \mu_t = \frac{\gamma_t}{\lambda} and
\sigma_t^2 = \frac{\gamma_t}{\lambda^2}. To study the behavior of
m^t as t \to \infty, it is therefore enough to track the evolution of
\gamma_t.

This heuristic analysis of AMP gives a phase transition at
\lambda = 1 (in fact, the analysis of AMP can be made rigorous, as in [5]):

  • For \lambda < 1: if \gamma_t \approx 0, then |\gamma_t + \sqrt{\gamma_t}G| \ll 1 w.h.p., and since f(x) = \tanh(x) \approx x near 0, we have \gamma_{t + 1} \approx \lambda^2 \gamma_t < \gamma_t. Taking t \to \infty gives \gamma_{\infty} = 0, which means the AMP solution has no overlap with the ground truth.

  • For \lambda > 1: the same linearization gives \gamma_{t + 1} \approx \lambda^2 \gamma_t > \gamma_t near 0, so the trivial fixed point is unstable and \gamma_t converges to a positive fixed point. In this case, AMP’s solution has nontrivial correlation with the ground truth.

[Figure not reproduced]

(Figure from [6])
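The recursion for \gamma_t with f = \tanh can be iterated directly. The sketch below evaluates the Gaussian expectation with probabilists' Gauss–Hermite quadrature (`numpy.polynomial.hermite_e.hermegauss`); the quadrature order, starting value \gamma_0 = 0.1, and iteration count are our illustrative choices.

```python
import numpy as np

# Probabilists' Gauss-Hermite rule: sum_i w_i g(z_i) approximates the integral of g(z) exp(-z^2/2).
NODES, WEIGHTS = np.polynomial.hermite_e.hermegauss(80)
WEIGHTS = WEIGHTS / np.sqrt(2.0 * np.pi)  # now sum_i w_i g(z_i) approximates E[g(G)], G ~ N(0, 1)


def se_map(gamma, lam):
    """One state-evolution step: gamma -> lam^2 * E[tanh(gamma + sqrt(gamma) G)]."""
    return lam**2 * np.sum(WEIGHTS * np.tanh(gamma + np.sqrt(gamma) * NODES))


def se_fixed_point(lam, gamma0=0.1, iters=2000):
    """Iterate the state evolution from gamma0 and return the (approximate) limit gamma_inf."""
    gamma = gamma0
    for _ in range(iters):
        # Clamp tiny negative values caused by floating-point cancellation near gamma = 0.
        gamma = max(se_map(gamma, lam), 0.0)
    return gamma


for lam in [0.8, 0.95, 1.2, 1.5, 2.0]:
    g = se_fixed_point(lam)
    # At a fixed point, the overlap E[tanh(gamma + sqrt(gamma) G)] equals gamma / lam^2.
    print(f"lambda={lam}: gamma_inf={g:.4f}, overlap={g / lam**2:.4f}")
```

For \lambda < 1 the iterates collapse to 0, while for \lambda > 1 they settle at a positive value that grows with \lambda.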

Replica symmetry trick

Another way of obtaining the phase transition is via a non-rigorous
analytic method called the replica method. Although non-rigorous, this
method from statistical physics has been used to predict the fixed point
of many message passing algorithms and has the advantage of being easy
to simulate. In our case, we will see that we obtain the same phase
transition temperature as AMP above. The method is non-rigorous due to
several assumptions made during the computation.

Outline of replica method

Recall that we are interested in the free energy of a given
system, which we write (dropping the conventional minus sign) as f(\beta, Y) = \frac{1}{\beta n} \log Z(\beta, Y), where Z is
the partition function as before:
Z(\beta, Y) = \sum_{x \in \{\pm 1\}^n} \exp(-\beta H(Y, x)) and
H(Y, x) = -<Y, x^Tx> = -xYx^T = -\sum_{i, j} Y_{i, j}x_ix_j (here x is a row vector).
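For tiny systems these definitions can be evaluated by brute force, which is handy for sanity-checking the formulas that follow. A minimal sketch, assuming n is small enough to enumerate all 2^n configurations and using the post's sign convention f(\beta, Y) = \frac{1}{\beta n}\log Z; the function names are ours.

```python
import itertools

import numpy as np


def log_partition(Y, beta):
    """log Z(beta, Y) = log sum_{x in {-1,+1}^n} exp(beta * x Y x^T), by enumeration (tiny n only)."""
    n = Y.shape[0]
    xs = np.array(list(itertools.product([-1.0, 1.0], repeat=n)))
    # -beta * H(Y, x) = beta * sum_{i,j} Y_ij x_i x_j
    exponents = beta * np.einsum("si,ij,sj->s", xs, Y, xs)
    m = exponents.max()
    return m + np.log(np.sum(np.exp(exponents - m)))  # log-sum-exp for stability


def free_energy_density(Y, beta):
    # The post's convention: f(beta, Y) = (1 / (beta n)) log Z, with no minus sign.
    return log_partition(Y, beta) / (beta * Y.shape[0])


rng = np.random.default_rng(5)
n, lam = 12, 1.5
a = rng.choice([-1.0, 1.0], size=n)
G = rng.standard_normal((n, n))
Y = (lam / n) * np.outer(a, a) + (G + G.T) / np.sqrt(2.0 * n)
print(free_energy_density(Y, beta=lam))
```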

In the replica method, Y is not fixed but a random variable. The
assumption is that as n \to \infty, the free energy concentrates around its mean
(it is self-averaging), so we look at the mean of f(\beta, Y) over Y to approximate the free
energy of the system:

f(\beta) = \lim_{n \to \infty}\frac{1}{\beta n}{{\mathbb{E}}}_{Y}[\log Z(\beta, Y)]

f(\beta) is called the free energy density, and the goal now is to
compute it as a function of only \beta, the
inverse temperature of the system.

The replica method was originally proposed as a way to simplify the
computation of f(\beta).

It is generally hard to compute f(\beta) in closed form. A
naive attempt to approximate f(\beta) is to move the expectation inside the log:
g(\beta) = \frac{1}{\beta n}\log {{\mathbb{E}}}_Y[Z(\beta, Y)].
Unfortunately, g(\beta) and f(\beta) are quite different quantities,
at least when the temperature is low. Intuitively, f(\beta) describes a
system with a fixed Y, while in g(\beta), x and Y are allowed to
fluctuate together. When the temperature is high, Y doesn’t play a big
role in the system, so the two could be close. However, when the temperature is
low, there can be problems. As \beta \to \infty, \log Z(\beta, Y) \approx \beta x_Y Y x_Y^T, where x_Y is the ground state (the maximizer of xYx^T), so
\beta n f(\beta) \approx \int_Y (\beta x_Y Y x_Y^T)\mu(Y) dY, while
\beta n g(\beta) \approx \log \int_Y \exp(\beta x_Y Y x_Y^T)\mu(Y)dY, which is dominated by the rare Y whose ground-state energy x_Y Y x_Y^T is atypically large.
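The gap between the two quantities can be observed in a small brute-force experiment. The sketch below assumes a tiny n (so \log Z can be enumerated) and replaces the expectation over Y by a Monte Carlo average over 200 samples; the sizes and temperatures are illustrative. By Jensen's inequality g \geq f always holds, and one can compare the size of the gap at high and low temperature.

```python
import itertools

import numpy as np


def sample_Y(n, lam, rng):
    a = rng.choice([-1.0, 1.0], size=n)
    G = rng.standard_normal((n, n))
    return (lam / n) * np.outer(a, a) + (G + G.T) / np.sqrt(2.0 * n)


def log_partition(Y, beta):
    n = Y.shape[0]
    xs = np.array(list(itertools.product([-1.0, 1.0], repeat=n)))
    e = beta * np.einsum("si,ij,sj->s", xs, Y, xs)
    return e.max() + np.log(np.exp(e - e.max()).sum())


def log_mean_exp(v):
    return v.max() + np.log(np.mean(np.exp(v - v.max())))


def quenched_and_annealed(n, lam, beta, samples, rng):
    logZ = np.array([log_partition(sample_Y(n, lam, rng), beta) for _ in range(samples)])
    f_hat = logZ.mean() / (beta * n)         # estimate of (1/(beta n)) E_Y[log Z]
    g_hat = log_mean_exp(logZ) / (beta * n)  # estimate of (1/(beta n)) log E_Y[Z]
    return f_hat, g_hat


rng = np.random.default_rng(6)
f_lo, g_lo = quenched_and_annealed(10, 1.0, 0.2, 200, rng)  # high temperature
f_hi, g_hi = quenched_and_annealed(10, 1.0, 5.0, 200, rng)  # low temperature
print(f"beta=0.2: f={f_lo:.4f}, g={g_lo:.4f}")
print(f"beta=5.0: f={f_hi:.4f}, g={g_hi:.4f}")
```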

While {{\mathbb{E}}}_Y[\log Z] is hard to compute, the moments
{{\mathbb{E}}}_Y[Z^r] are often much easier to handle. The
replica trick starts by rewriting f(\beta) in terms of moments of Z.
Recall that x^r = \exp(r \log x) \approx 1 + r \log x for r \approx 0 and
\ln(1 + x)\approx x for x \approx 0, so that \frac{1}{r}\ln{{\mathbb{E}}}[Z^r] \approx \frac{1}{r}\ln(1 + r{{\mathbb{E}}}[\log Z]) \approx {{\mathbb{E}}}[\log Z]. This gives the following:

Claim 1. Let f_r(\beta) = \frac{1}{r \beta n}\ln[{{\mathbb{E}}}_Y[Z(\beta, Y)^r]]
Then, f(\beta) = \lim_{r \to 0}f_r(\beta)
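Claim 1 can be checked numerically on any sample of \log Z values. As a stand-in (our assumption, not part of the post) we draw \log Z from a Gaussian, in which case \frac{1}{r}\ln{{\mathbb{E}}}[Z^r] = {{\mathbb{E}}}[\log Z] + \frac{r}{2}\mathrm{Var}(\log Z) exactly, so the estimate should decrease to the mean as r \to 0.

```python
import numpy as np


def replica_estimate(logZ_samples, r):
    """(1/r) log E[Z^r], computed stably from samples of log Z."""
    a = r * np.asarray(logZ_samples)
    m = a.max()
    return (m + np.log(np.mean(np.exp(a - m)))) / r


rng = np.random.default_rng(7)
# Stand-in for samples of log Z (Gaussian by assumption).
logZ = rng.normal(loc=5.0, scale=2.0, size=10_000)
for r in [1.0, 0.1, 0.01, 0.001]:
    print(r, replica_estimate(logZ, r), logZ.mean())
```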

The idea of the replica method is quite simple:

  • Define a function f(r, \beta) for r \in \mathbb{Z}_+ s.t. f(r, \beta) = f_r(\beta) for all such r.

  • Extend f(r, \beta) analytically to all r \in {{\mathbb{R}}}_+ and take the limit of r \to 0.

The second step may sound crazy, but for reasons that are still not well understood, it has
been surprisingly effective at making correct predictions.

The name “replica” comes from the way {{\mathbb{E}}}[Z^r] in Claim 1 is
computed: we expand the r-th moment in terms of r replicas of the system,
that is, r copies sharing the same Y:

Z(\beta, Y)^r = (\sum_x \exp(-\beta H(Y, x)))^r = \sum_{x^1, \cdots, x^r} \prod_{k = 1}^r \exp(-\beta H(Y, x^k))
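The replica expansion is just the distributive law, which a tiny brute-force check makes concrete. The sketch below (sizes n = 4 and r = 3 are arbitrary small choices) compares Z^r with the explicit sum over all r-tuples of configurations.

```python
import itertools

import numpy as np


def boltzmann_weights(Y, beta):
    """exp(-beta H(Y, x)) = exp(beta * x Y x^T) for every x in {-1,+1}^n."""
    n = Y.shape[0]
    xs = np.array(list(itertools.product([-1.0, 1.0], repeat=n)))
    return np.exp(beta * np.einsum("si,ij,sj->s", xs, Y, xs))


def Z_power_direct(Y, beta, r):
    return boltzmann_weights(Y, beta).sum() ** r


def Z_power_replicas(Y, beta, r):
    # Sum over all r-tuples (x^1, ..., x^r) of prod_k exp(-beta H(Y, x^k)).
    w = boltzmann_weights(Y, beta)
    return sum(np.prod(w[list(idx)]) for idx in itertools.product(range(len(w)), repeat=r))


rng = np.random.default_rng(8)
n = 4
G = rng.standard_normal((n, n))
Y = (G + G.T) / np.sqrt(2.0 * n)
print(Z_power_direct(Y, 0.5, 3), Z_power_replicas(Y, 0.5, 3))
```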

For Rademacher spiked Wigner model

In this section, we will see how one can apply the replica trick to
obtain the phase transition in the Rademacher spiked Wigner model. Recall
that given a hidden a \in \{\pm 1\}^n (a row vector), the observation is
Y = \frac{\lambda}{n}a^Ta + \frac{1}{\sqrt n} W, where
W_{i, j} \sim \mathcal{N}(0, 1) for i \neq j and W_{i, i} \sim \mathcal{N}(0, 2).
We are interested in finding the smallest \lambda at which we can still
recover a solution with some correlation to the ground truth a. Note
that the diagonal \{W_{i, i}\} is not important here, as x_i^2 = 1 doesn’t carry
any information.

Guided by the posterior {{\mathbb{P}}}[x|Y], we set up the following
system corresponding to the Rademacher spiked Wigner model:

  • the system consists of n particles, and the interactions between
    pairs of particles are given by Y;

  • the signal-to-noise ratio \lambda plays the role of the inverse temperature
    \beta.

Following the steps above, we begin by computing
f(r, \beta) = \frac{1}{r\beta n}\ln{{\mathbb{E}}}_Y[Z^r]
for r \in \mathbb{Z}_+. Denote X^k = (x^k)^Tx^k, where x^k is the
kth replica of the system. Then

{{\mathbb{E}}}_Y[Z^r] = \int_Y \sum_{x^1, \cdots, x^r} \exp(\beta \sum_k <Y, X^k>) \mu(Y) dY\\ = \int_Y \sum_{x^1, \cdots, x^r} \exp(\beta <Y, \sum_k X^k>) \mu(Y) dY

We then simplify the above expression with a technical claim.

Claim 2. Let Y = A + \frac{1}{\sqrt{n}}W where A is a fixed matrix and
W is the GOE matrix defined as above. Then,
\int_Y exp(\beta<Y, X>) \mu(Y) dY = exp(\frac{\beta^2}{n}{{\|{X}\|_{F}}}^2 + \frac{\beta}{2} <A, X>)

Denote X = \sum_k X^k and apply Claim 2 with
A = \frac{\beta}{n}a^Ta (recall that \beta = \lambda). We get
{{\mathbb{E}}}_Y[Z^r] = \sum_{x^1, \cdots, x^r} \exp(\frac{\beta^2}{n}{{\|{X}\|_{F}}}^2 + \frac{\beta^2}{2n} <a^Ta, X>)
To understand the terms inside the exponent better, we can rewrite them
in terms of the overlaps between replicas:

{{\|{X}\|_{F}}}^2 = \sum_{i, j}X_{i, j}^2 = \sum_{i, j}(\sum_{k = 1}^r x^k_ix^k_j)^2 =\sum_{i, j}(\sum_{k = 1}^r x^k_ix^k_j)(\sum_{l = 1}^r x^l_ix^l_j)\\ = \sum_{k, l} (\sum_{i = 1}^n x^k_ix^{l}_i)^2 = \sum_{k, l} <x^k, x^l>^2

where the last equality follows from rearranging and switching the inner
and outer summations.

Using a similar trick, we can view the other term as

<a^Ta, X> = \sum_{i, j}\sum_{k = 1}^rx^k_ix^k_ja_ia_j = \sum_{k = 1}^r (\sum_{i = 1}^n a_ix^k_i)^2 = \sum_{k}<a, x^k>^2

Note that Q_{k, l} = <x^k, x^l> represents the overlap between the
kth and lth replicas, and Q_k = <a, x^k> represents the
overlap between the kth replica and the ground-truth vector.
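Both overlap identities, {{\|{X}\|_{F}}}^2 = \sum_{k, l} Q_{k, l}^2 and <a^Ta, X> = \sum_k Q_k^2, are exact for any sign vectors, so they can be checked on random replicas. A minimal sketch with arbitrary sizes n = 50 and r = 4, storing vectors as rows to match the a^Ta convention used here:

```python
import numpy as np

rng = np.random.default_rng(9)
n, r = 50, 4
a = rng.choice([-1.0, 1.0], size=n)          # ground truth a (a row vector)
xs = rng.choice([-1.0, 1.0], size=(r, n))    # replicas x^1, ..., x^r as rows

X = sum(np.outer(xk, xk) for xk in xs)       # X = sum_k (x^k)^T x^k
Q = xs @ xs.T                                # Q[k, l] = <x^k, x^l>
Q_truth = xs @ a                             # Q_k = <a, x^k>

frob_lhs, frob_rhs = np.sum(X**2), np.sum(Q**2)                        # ||X||_F^2 vs sum_{k,l} Q_{k,l}^2
truth_lhs, truth_rhs = np.sum(np.outer(a, a) * X), np.sum(Q_truth**2)  # <a^T a, X> vs sum_k Q_k^2
print(frob_lhs, frob_rhs, truth_lhs, truth_rhs)
```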

In the end, we get, for any integer r (Equation 1):

\displaystyle f(r, \beta) = \frac{1}{r\beta n}\ln\Big(\sum_{x^1, \cdots, x^r} \exp\big(\frac{\beta^2}{n}\sum_{k, l}Q_{k, l}^2 + \frac{\beta^2}{2n}\sum_k Q_k^2\big)\Big)\\ = \frac{1}{r\beta n} \ln\Big(\sum_{Q}\nu(Q)\exp\big(\frac{\beta^2}{n}\sum_{k, l}Q_{k, l}^2 + \frac{\beta^2}{2n}\sum_k Q_k^2\big)\Big)

where \nu(Q) is the number of replica configurations (x^1, \cdots, x^r) with overlaps Q.

Our goal now is to approximate this quantity. Intuitively, if we think
of the normalized overlaps \frac{1}{n}Q_{k, l} (with index 0 standing for the ground truth a, so that Q_{0, k} = Q_k) as the entries of an (r + 1) \times (r + 1) matrix Q
with Q(i,i) = 1, then Q is the average of n i.i.d. matrices. So for independent replicas we
expect |Q_{j, k}| = O(\frac{1}{\sqrt{n}}) for j \neq k w.h.p. In the
remaining part, we find the correct Q by rewriting Equation 1.

Observe that we can introduce a new variable Z_{k, l} for each k \neq l (and Z_k for each k) and
use the Gaussian integral identity \int \exp(-aZ^2 + bZ)dZ = \sqrt{\pi/a}\,\exp(\frac{b^2}{4a}) (Equation 4):

\exp(\frac{\beta^2}{n}Q_{k, l}^2) = \sqrt{\frac{n}{4\pi}}\int_{Z_{k, l}} \exp(-\frac{n}{4}Z_{k, l}^2 + \beta Q_{k, l}Z_{k, l})dZ_{k, l}

\exp(\frac{\beta^2}{2n}Q_k^2) = \sqrt{\frac{n}{2\pi}}\int_{Z_k}\exp(-\frac{n}{2}Z_k^2 + \beta Q_{k}Z_k)dZ_k
Replacing each \exp(\frac{\beta^2}{n}Q_{k, l}^2) with k \neq l and each \exp(\frac{\beta^2}{2n}Q_k^2) by such an integral, and noting that the diagonal terms Q_{k, k} = n contribute the constant factor \exp(r\beta^2 n), we
have (Equation 2):

{{\mathbb{E}}}[Z^r] = \sum_{x^1, \cdots, x^r} \exp(\frac{\beta^2}{n}\sum_{k, l}Q_{k, l}^2 + \frac{\beta^2}{2n}\sum_k Q_k^2)\\ = C\exp(r\beta^2 n)\sum_{x^1, \cdots, x^r} \int_{Z}\exp(-\frac{n}{4}\sum_{k \neq l}Z_{k, l}^2 - \frac{n}{2}\sum_k Z_k^2 + \beta \sum_{k \neq l}Z_{k, l}Q_{k, l} + \beta\sum_kZ_k Q_k) dZ \\ = C\exp(r\beta^2 n) \int_{Z}\exp\Big(-\frac{n}{4}\sum_{k \neq l}Z_{k, l}^2 - \frac{n}{2}\sum_k Z_k^2 + \ln\Big(\sum_{x^1,\cdots, x^r}\exp(\beta \sum_{k\neq l}Z_{k, l}Q_{k, l} + \beta\sum_kZ_k Q_k)\Big)\Big) dZ

where C is the constant coming from the normalizations of the Gaussian integrals.
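The step above rests on the Gaussian integral \int \exp(-aZ^2 + bZ)dZ = \sqrt{\pi/a}\,\exp(\frac{b^2}{4a}). A quick numerical sketch checks the first identity of Equation 4, which uses a = n/4 and b = \beta Q; the values of n, \beta, Q and the integration grid are arbitrary choices for illustration.

```python
import numpy as np


def gaussian_integral(a, b, half_width=60.0, num=200_001):
    """Riemann-sum approximation of the integral of exp(-a Z^2 + b Z) over a wide grid."""
    Z = np.linspace(-half_width, half_width, num)
    dZ = Z[1] - Z[0]
    return np.sum(np.exp(-a * Z**2 + b * Z)) * dZ


n, beta, Q = 20, 0.7, 3.0
# First identity of Equation 4: a = n / 4, b = beta * Q.
lhs = np.exp(beta**2 * Q**2 / n)
rhs = np.sqrt(n / (4 * np.pi)) * gaussian_integral(n / 4, beta * Q)
print(lhs, rhs)
```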

To compute the integral in (Equation 2), we need to cheat a little and take
n \to \infty before letting r \to 0. Note that the free energy density
is defined as
f(\beta) = \lim_{n \to \infty}\lim_{r \to 0}\frac{1}{r\beta n}\ln {{\mathbb{E}}}_Y[Z(\beta, Y)^r].
Swapping the order of the two limits is the second assumption made in the replica method, and it is
commonly believed that this is okay here. Physically,
this is plausible because we believe intrinsic physical quantities
should not depend on the system size.

Now the Laplace method tells us that when n \to \infty, the integral in (Equation 2) is dominated by the maximum of the exponent.

Theorem 1 (Laplace Method). Let h(x): {{\mathbb{R}}}^d \to {{\mathbb{R}}} be twice differentiable with a unique maximum at an interior point x^*. Then, as n \to \infty,

\int e^{nh(x)}dx \approx e^{nh(x^*)}(\frac{2\pi}{n})^{\frac{d}{2}}\frac{1}{\sqrt{\det(-H)}}

where x^* = \arg\max_x \{h(x)\} and H is the Hessian of h evaluated at the point x^*.
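A one-dimensional sanity check of the Laplace method, assuming the test function h(x) = \log x - x on (0, \infty) (our choice, not from the post). Here \int_0^\infty x^n e^{-nx}dx = \Gamma(n + 1)/n^{n + 1} exactly, h is maximized at x^* = 1 with h(1) = -1 and h''(1) = -1, so the approximation reduces to Stirling's formula and the ratio should tend to 1.

```python
import math


def laplace_ratio(n):
    """Ratio of the exact integral to its Laplace approximation for h(x) = log(x) - x.

    Exact: int_0^inf x^n e^{-n x} dx = Gamma(n + 1) / n^(n + 1).
    Laplace: e^{n h(1)} (2 pi / n)^{1/2}, since d = 1 and h''(1) = -1.
    """
    log_exact = math.lgamma(n + 1) - (n + 1) * math.log(n)
    log_laplace = -n + 0.5 * math.log(2 * math.pi / n)
    return math.exp(log_exact - log_laplace)


for n in [10, 100, 1000]:
    print(n, laplace_ratio(n))  # approaches 1 as n grows
```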

We apply the Laplace method with h equal to \frac{1}{n} times the exponent in (Equation 2). Writing Z_{0, k} = Z_k and using Z_{k, l} = Z_{l, k}, this is
h(Z) = -\frac{1}{2}\sum_{0 \leq k < l \leq r}Z_{k, l}^2 + \frac{1}{n}\ln A(Z),
where
A(Z) = \sum_{x^1,\cdots, x^r}\exp(\beta \sum_{k \neq l}Z_{k, l}Q_{k, l} + \beta\sum_kZ_k Q_k).
What’s left to do is to find the critical point of h. Setting the
derivative with respect to Z_k to zero gives
-Z_k + \frac{\beta}{n}\langle Q_k\rangle = 0,
where \langle \cdot \rangle denotes the average over (x^1, \cdots, x^r) weighted by the summands of A(Z); the condition for Z_{k, l} is analogous.

We now need to find a critical point of h at which the Hessian of -h is PSD. To
do that, we assume that the order of the replicas does not matter,
which is referred to as the replica symmetric case (see footnote 1). The simplest such form
of Z is the following: Z_{k, l} = y for all 1 \leq k < l \leq r and
Z_{k} = y for all k, for some y. This also implies that the overlaps take a common value q, and
y = \frac{\beta}{n} q.

Plugging this back into Equation 2 gives (Equation 3):

{{\mathbb{E}}}[Z^r] \approx C\exp(r\beta^2 n)\exp\Big(-\frac{n}{2}\cdot\frac{r^2 - r}{2}y^2 - \frac{n}{2}ry^2 + \ln\Big(\sum_{x^1, \cdots, x^r}\exp\big(y\beta\sum_{k \neq l}Q_{k, l} + y\beta \sum_k Q_k\big)\Big)\Big)

To obtain f(r, \beta), we only need to deal with the last term in
(Equation 3) as r \to 0. Using the fact that Z_{k, l} = y for all
k, l, and using the same trick of introducing a new Gaussian integral as
in (Equation 4), we have
\lim_{r \to 0}\frac{1}{rn}\ln\Big(\sum_{x^1, \cdots, x^r}\exp\big(y\beta\sum_{k \neq l}Q_{k, l} + y\beta \sum_k Q_k\big)\Big) = -y\beta + {{\mathbb{E}}}_{z \sim \mathcal{N}(0, 1)}[\log(2\cosh(y\beta + \sqrt{y\beta}z))]

Requiring the solution to be a stationary point of the free energy, that is,
setting the derivative of the current expression with respect to y to zero, gives
\frac{y}{\beta} = {{\mathbb{E}}}_z[\tanh(y\beta + \sqrt{y\beta}z)].
Writing \gamma = y\beta and \beta = \lambda, this reads \gamma = \lambda^2{{\mathbb{E}}}_z[\tanh(\gamma + \sqrt{\gamma}z)],
which matches the fixed point of the AMP state evolution. Plugging q and y back in gives
f(\beta). The curve of f(\beta) looks like the figure below, where
the solid line is the curve of f(\beta) with the given q and the
dotted line is the curve obtained by setting all variables to 0.

[Figure: f(\beta) for the replica symmetric solution (solid) and for the trivial solution with all variables set to 0 (dotted)]

 

References

[1] Amelia Perry, Alexander S. Wein, Afonso S. Bandeira, and Ankur Moitra. Optimality and sub-optimality of PCA for spiked random matrices and synchronization. arXiv preprint arXiv:1609.05573, 2016.
[2] D. Féral and S. Péché. The largest eigenvalue of rank one deformation of large Wigner matrices. Communications in Mathematical Physics, 272:185–228, May 2007.
[3] Afonso S. Bandeira, Amelia Perry, and Alexander S. Wein. Notes on computational-to-statistical gaps: predictions using statistical physics. arXiv preprint arXiv:1803.11132, 2018.
[4] Yash Deshpande, Emmanuel Abbe, and Andrea Montanari. Asymptotic mutual information for the binary stochastic block model. In Information Theory (ISIT), 2016 IEEE International Symposium on, pages 185–189. IEEE, 2016.
[5] Adel Javanmard and Andrea Montanari. State evolution for general approximate message passing algorithms, with applications to spatial coupling. Information and Inference: A Journal of the IMA, 2(2):115–144, 2013.
[6] A. Perry, A. S. Wein, and A. S. Bandeira. Statistical limits of spiked tensor models. arXiv e-prints, December 2016.

  1. It turns out that for this problem, the replica symmetric solution is the correct one. We
    will not talk about replica symmetry breaking here, which
    intuitively means we partition the replicas into groups and recurse.
