Johnson-Lindenstrauss Lemma

The Johnson-Lindenstrauss (JL) Lemma is an important result in mathematics: it shows that a set of data points can be mapped into a much lower-dimensional space while preserving all pairwise distances up to a small, controlled distortion. In this blog post, we introduce the JL Lemma and present a proof of the theorem.

The proof follows Dasgupta and Gupta's paper "An elementary proof of a theorem of Johnson and Lindenstrauss". For a more detailed explanation, please refer to the original paper.

JL Lemma

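For any set $O$ of $n$ points in $\mathbb{R}^d$, there exists a map $f:\mathbb{R}^d \to \mathbb{R}^k$ such that for all $o_i, o_j \in O$:

$$(1-\epsilon)\|o_i - o_j\|_2^2 \le \|f(o_i) - f(o_j)\|_2^2 \le (1+\epsilon)\|o_i - o_j\|_2^2$$

where $0<\epsilon<1$ and $k$ is a positive integer satisfying:

$$k > \frac{24}{3\epsilon^2 - 2\epsilon^3}\log n$$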

Let $A$ be a random matrix of size $k\times d$, where each entry is independently drawn from the standard normal distribution $N(0,1)$. Using this random matrix, we can construct a mapping that satisfies the JL Lemma with positive probability. For any point $p\in\mathbb{R}^d$, the mapping is defined as:

$$f(p) = \frac{1}{\sqrt{k}}Ap$$
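The construction above can be sketched in a few lines of plain Python. This is only an illustrative sketch, not code from the original post: the helper names `jl_min_dim`, `gaussian_matrix` and `project` are hypothetical, the logarithm is assumed to be the natural logarithm (as the proof's final step $e^{-2\log n} = 1/n^2$ suggests), and the standard library is used instead of a numerical package so the example stays self-contained.

```python
import math
import random


def jl_min_dim(n: int, eps: float) -> int:
    """Smallest integer k with k > 24 / (3*eps^2 - 2*eps^3) * ln(n)."""
    if n < 2 or not 0.0 < eps < 1.0:
        raise ValueError("need n >= 2 and 0 < eps < 1")
    bound = 24.0 / (3.0 * eps**2 - 2.0 * eps**3) * math.log(n)
    return math.floor(bound) + 1


def gaussian_matrix(k: int, d: int, rng: random.Random) -> list[list[float]]:
    """k x d matrix whose entries are independent N(0, 1) draws."""
    return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(k)]


def project(A: list[list[float]], p: list[float]) -> list[float]:
    """f(p) = A p / sqrt(k), where k is the number of rows of A."""
    scale = 1.0 / math.sqrt(len(A))
    return [scale * sum(a * x for a, x in zip(row, p)) for row in A]


rng = random.Random(0)                 # fixed seed, only for reproducibility
k = jl_min_dim(n=1000, eps=0.5)        # target dimension for 1000 points
A = gaussian_matrix(k, d=20, rng=rng)  # d = 20 is an arbitrary example value
y = project(A, [1.0] * 20)
print(k, len(y))
```

Note that the target dimension depends only on $n$ and $\epsilon$, never on the original dimension $d$.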

Proof of JL Lemma

To prove the JL Lemma, we first establish some auxiliary lemmas and inequalities. For convenience, we omit the subscript 2 on the norm, so $\|\cdot\|$ denotes the 2-norm by default.

Lemma 1

For any point $p\in\mathbb{R}^d$, let $Y=\frac{1}{\sqrt{k}}Ap$. Then, we have:

$$E[\|Y\|^2] = \|p\|^2$$

Proof of Lemma 1

From the definition of $Y$, its $i$-th coordinate $Y_i$ is:

$$Y_i = \frac{1}{\sqrt{k}}\sum_{j=1}^{d}A_{ij}p_j$$

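Now, we compute $E[\|Y\|^2]$:

$$E[\|Y\|^2] = E\left[\sum_{i=1}^{k}\frac{1}{k}\left(\sum_{j=1}^{d}A_{ij}p_j\right)^2\right] = E\left[\sum_{i=1}^{k}\frac{1}{k}\sum_{j=1}^{d}\sum_{t=1}^{d}A_{ij}A_{it}p_jp_t\right] = \sum_{i=1}^{k}\frac{1}{k}\sum_{j=1}^{d}\sum_{t=1}^{d}p_jp_tE[A_{ij}A_{it}]$$

  • if $j\neq t$, then by independence: $E[A_{ij}A_{it}] = E[A_{ij}]E[A_{it}] = 0$
  • if $j = t$, then: $E[A_{ij}A_{it}] = E[A_{ij}^2] = \mathrm{Var}(A_{ij}) + E^2[A_{ij}] = 1$

Thus, only the terms with $j = t$ survive, and we get:

$$E[\|Y\|^2] = \sum_{i=1}^{k}\frac{1}{k}\sum_{j=1}^{d}p_j^2 = \sum_{i=1}^{k}\frac{1}{k}\|p\|^2 = \|p\|^2$$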
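Lemma 1 can also be checked empirically. The following is a minimal Monte Carlo sketch, assuming a small example vector and arbitrary values for $k$, the number of trials and the seed; the function names are hypothetical and chosen for illustration only.

```python
import random


def squared_norm_of_projection(p, k, rng):
    """Draw a fresh k x d Gaussian matrix A and return ||A p / sqrt(k)||^2."""
    total = 0.0
    for _ in range(k):
        # One row of A dotted with p; the rows are independent N(0, 1) vectors.
        row_dot = sum(rng.gauss(0.0, 1.0) * x for x in p)
        total += row_dot * row_dot
    return total / k


def estimate_expected_squared_norm(p, k, trials, seed=0):
    """Monte Carlo estimate of E[||Y||^2] for Y = A p / sqrt(k)."""
    rng = random.Random(seed)
    return sum(squared_norm_of_projection(p, k, rng) for _ in range(trials)) / trials


p = [3.0, -4.0, 12.0]                  # ||p||^2 = 9 + 16 + 144 = 169
estimate = estimate_expected_squared_norm(p, k=4, trials=20000)
print(estimate, sum(x * x for x in p))
```

Even with a very small $k$ the average stays close to $\|p\|^2$: the lemma is a statement about the expectation only, while the concentration around that expectation is what Lemma 3 quantifies.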

Lemma 2

Let $X$ be a random variable distributed as $N(0,1)$. For a constant $\lambda<\frac{1}{2}$, we have:

$$E[e^{\lambda X^2}] = \frac{1}{\sqrt{1-2\lambda}}$$

Proof of Lemma 2

The PDF of the standard normal distribution is:

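$$f(x) = \frac{1}{\sqrt{2\pi}}e^{-\frac{x^2}{2}}$$

We now compute $E[e^{\lambda X^2}]$:

$$E[e^{\lambda X^2}] = \int_{-\infty}^{+\infty}\frac{1}{\sqrt{2\pi}}e^{\lambda x^2}e^{-\frac{x^2}{2}}\,dx = \frac{1}{\sqrt{2\pi}}\int_{-\infty}^{+\infty}e^{-\frac{(1-2\lambda)x^2}{2}}\,dx$$

Next, let $y = x\sqrt{1-2\lambda}$, so that $dy = \sqrt{1-2\lambda}\,dx$. Substituting this into the integral, we get:

$$E[e^{\lambda X^2}] = \frac{1}{\sqrt{1-2\lambda}}\cdot\frac{1}{\sqrt{2\pi}}\int_{-\infty}^{+\infty}e^{-\frac{y^2}{2}}\,dy$$

Using the well-known result for the standard normal distribution:

$$\frac{1}{\sqrt{2\pi}}\int_{-\infty}^{+\infty}e^{-\frac{y^2}{2}}\,dy = 1$$

Therefore:

$$E[e^{\lambda X^2}] = \frac{1}{\sqrt{1-2\lambda}}$$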
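As a sanity check on Lemma 2, the integral can be approximated numerically and compared with the closed form. The sketch below uses a plain trapezoid rule on a truncated interval; the truncation width, the step count and the function names are illustrative assumptions, not part of the original proof.

```python
import math


def mgf_closed_form(lam: float) -> float:
    """Lemma 2: E[exp(lam * X^2)] = 1 / sqrt(1 - 2*lam) for lam < 1/2."""
    if lam >= 0.5:
        raise ValueError("the expectation is finite only for lam < 1/2")
    return 1.0 / math.sqrt(1.0 - 2.0 * lam)


def mgf_numeric(lam: float, half_width: float = 40.0, steps: int = 200_000) -> float:
    """Trapezoid-rule approximation of the integral of exp(lam * x^2) * pdf(x)."""
    h = 2.0 * half_width / steps

    def integrand(x: float) -> float:
        # exp(lam * x^2) * exp(-x^2 / 2) / sqrt(2 * pi), merged into one exponent
        return math.exp((lam - 0.5) * x * x) / math.sqrt(2.0 * math.pi)

    total = 0.5 * (integrand(-half_width) + integrand(half_width))
    for i in range(1, steps):
        total += integrand(-half_width + i * h)
    return total * h


for lam in (-1.0, 0.25, 0.4):
    print(lam, mgf_numeric(lam), mgf_closed_form(lam))
```

As $\lambda$ approaches $\frac{1}{2}$ the integrand decays more and more slowly, which is why the closed form blows up at $\lambda = \frac{1}{2}$ and the lemma requires $\lambda < \frac{1}{2}$.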

Lemma 3

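For any point $p\in\mathbb{R}^d$, let $Y=\frac{1}{\sqrt{k}}Ap$. The following inequalities hold:

$$\begin{aligned}
\Pr\left[\|Y\|^2 \ge (1+\epsilon)\|p\|^2\right] &\le \frac{1}{n^2} \\
\Pr\left[\|Y\|^2 \le (1-\epsilon)\|p\|^2\right] &\le \frac{1}{n^2}
\end{aligned}$$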

Proof of Lemma 3

We begin by simplifying the inequalities. From the definition of Y, we have:

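$$Y_i = \frac{1}{\sqrt{k}}\sum_{j=1}^{d}A_{ij}p_j$$

We now compute $\|Y\|^2$:

$$\|Y\|^2 = \sum_{i=1}^{k}\left(\frac{1}{\sqrt{k}}\sum_{j=1}^{d}A_{ij}p_j\right)^2 = \sum_{i=1}^{k}\frac{1}{k}\left(\sum_{j=1}^{d}A_{ij}p_j\right)^2$$

Since the standard normal distribution is 2-stable, the sum $\sum_{j=1}^{d}A_{ij}p_j$ has the same distribution as $\|p\|X_i$ with $X_i\sim N(0,1)$. Using this well-known property, we have (with equality in distribution):

$$\|Y\|^2 = \sum_{i=1}^{k}\frac{1}{k}\|p\|^2X_i^2 = \frac{\|p\|^2}{k}\sum_{i=1}^{k}X_i^2$$

where the $X_i$ are independent standard normal random variables (independent because the rows of $A$ are independent). Substituting this back, the inequalities we need to prove reduce to:

$$\begin{aligned}
\Pr\left[\sum_{i=1}^{k}X_i^2 \ge k(1+\epsilon)\right] &\le \frac{1}{n^2} \\
\Pr\left[\sum_{i=1}^{k}X_i^2 \le k(1-\epsilon)\right] &\le \frac{1}{n^2}
\end{aligned}$$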
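The 2-stability step used above can be probed empirically: if $\sum_j A_{ij}p_j$ is distributed as $N(0,\|p\|^2)$, its mean should be near $0$, its second moment near $\|p\|^2$, and its fourth moment near $3\|p\|^4$. The sketch below compares sample moments with these targets; the example vector, the sample size and the helper names are assumptions made for illustration.

```python
import random


def row_dot_samples(p, trials, seed=0):
    """Samples of sum_j A_ij * p_j with independent A_ij ~ N(0, 1)."""
    rng = random.Random(seed)
    return [sum(rng.gauss(0.0, 1.0) * x for x in p) for _ in range(trials)]


def moments(samples):
    """Return the sample mean, second moment and fourth moment."""
    n = len(samples)
    mean = sum(samples) / n
    second = sum(s * s for s in samples) / n
    fourth = sum(s**4 for s in samples) / n
    return mean, second, fourth


p = [1.0, 2.0, 2.0]                    # ||p||^2 = 9
mean, second, fourth = moments(row_dot_samples(p, trials=100_000))
# For N(0, ||p||^2) the targets are 0, ||p||^2 = 9 and 3 * ||p||^4 = 243.
print(mean, second, fourth)
```

Matching a few moments does not prove normality, of course; it is only a quick plausibility check of the property the proof relies on.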

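For the first inequality, let $\lambda>0$, then:

$$\Pr\left[\sum_{i=1}^{k}X_i^2 \ge k(1+\epsilon)\right] = \Pr\left[e^{\lambda\sum_{i=1}^{k}X_i^2} \ge e^{\lambda k(1+\epsilon)}\right]$$

Using Markov's inequality and the fact that the $X_i$ are independent and identically distributed, we have:

$$\Pr\left[e^{\lambda\sum_{i=1}^{k}X_i^2} \ge e^{\lambda k(1+\epsilon)}\right] \le \frac{E\left[e^{\lambda\sum_{i=1}^{k}X_i^2}\right]}{e^{\lambda k(1+\epsilon)}} = \frac{\left(E\left[e^{\lambda X_1^2}\right]\right)^k}{e^{\lambda k(1+\epsilon)}}$$

Let $\lambda = \frac{\epsilon}{2(1+\epsilon)} < \frac{1}{2}$. By Lemma 2, and since $1-2\lambda = \frac{1}{1+\epsilon}$, we have:

$$\frac{\left(E\left[e^{\lambda X_1^2}\right]\right)^k}{e^{\lambda k(1+\epsilon)}} = \frac{\left(\frac{1}{\sqrt{1-2\lambda}}\right)^k}{e^{\lambda k(1+\epsilon)}} = \left[(1+\epsilon)e^{-\epsilon}\right]^{\frac{k}{2}}$$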

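Using the Taylor expansion for $\ln(1+\epsilon)$:

$$\ln(1+\epsilon) = \epsilon - \frac{\epsilon^2}{2} + \frac{\epsilon^3}{3} - \frac{\epsilon^4}{4} + \cdots$$

We have:

$$1+\epsilon < e^{\epsilon - \frac{\epsilon^2}{2} + \frac{\epsilon^3}{3}}$$

Thus:

$$\left[(1+\epsilon)e^{-\epsilon}\right]^{\frac{k}{2}} < \left(e^{\epsilon - \frac{\epsilon^2}{2} + \frac{\epsilon^3}{3}}e^{-\epsilon}\right)^{\frac{k}{2}} = e^{-\frac{k(3\epsilon^2 - 2\epsilon^3)}{12}}$$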

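If $k$ satisfies:

$$k > \frac{24}{3\epsilon^2 - 2\epsilon^3}\log n$$

then:

$$e^{-\frac{k(3\epsilon^2 - 2\epsilon^3)}{12}} < e^{-2\log n} = \frac{1}{n^2}$$

Therefore:

$$\Pr\left[\sum_{i=1}^{k}X_i^2 \ge k(1+\epsilon)\right] \le \frac{1}{n^2}$$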

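Similarly, for the second inequality, let $\lambda>0$. Multiplying by $-\lambda$ reverses the inequality, so Markov's inequality gives:

$$\Pr\left[\sum_{i=1}^{k}X_i^2 \le k(1-\epsilon)\right] = \Pr\left[e^{-\lambda\sum_{i=1}^{k}X_i^2} \ge e^{-\lambda k(1-\epsilon)}\right] \le \frac{E\left[e^{-\lambda\sum_{i=1}^{k}X_i^2}\right]}{e^{-\lambda k(1-\epsilon)}} = \frac{\left(E\left[e^{-\lambda X_1^2}\right]\right)^k}{e^{-\lambda k(1-\epsilon)}}$$

Let $\lambda = \frac{\epsilon}{2(1-\epsilon)}$, so that $-\lambda = \frac{\epsilon}{2(\epsilon-1)} < 0 < \frac{1}{2}$ and Lemma 2 applies with $-\lambda$ in place of $\lambda$. Since $1+2\lambda = \frac{1}{1-\epsilon}$, we have:

$$\frac{\left(E\left[e^{-\lambda X_1^2}\right]\right)^k}{e^{-\lambda k(1-\epsilon)}} = \frac{\left(\frac{1}{\sqrt{1+2\lambda}}\right)^k}{e^{-\lambda k(1-\epsilon)}} = \left[(1-\epsilon)e^{\epsilon}\right]^{\frac{k}{2}}$$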

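Similarly, using the Taylor expansion for $\ln(1-\epsilon)$:

$$\ln(1-\epsilon) = -\epsilon - \frac{\epsilon^2}{2} - \frac{\epsilon^3}{3} - \cdots$$

we have:

$$1-\epsilon < e^{-\epsilon - \frac{\epsilon^2}{2}}$$

Thus:

$$\left[(1-\epsilon)e^{\epsilon}\right]^{\frac{k}{2}} < e^{-\frac{k\epsilon^2}{4}} < e^{-\frac{6\epsilon^2}{3\epsilon^2 - 2\epsilon^3}\log n} < e^{-2\log n} = \frac{1}{n^2}$$

Therefore:

$$\Pr\left[\sum_{i=1}^{k}X_i^2 \le k(1-\epsilon)\right] \le \frac{1}{n^2}$$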
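The two Chernoff-style bounds derived in this proof are easy to evaluate directly. The sketch below computes $[(1+\epsilon)e^{-\epsilon}]^{k/2}$ and $[(1-\epsilon)e^{\epsilon}]^{k/2}$ for the smallest admissible $k$ and compares them with $1/n^2$. The function names and the sample values of $n$ and $\epsilon$ are illustrative assumptions, and $\log$ is taken to be the natural logarithm.

```python
import math


def upper_tail_bound(eps: float, k: int) -> float:
    """Bound [(1 + eps) * exp(-eps)]^(k/2) on Pr[sum X_i^2 >= k(1 + eps)]."""
    return ((1.0 + eps) * math.exp(-eps)) ** (k / 2.0)


def lower_tail_bound(eps: float, k: int) -> float:
    """Bound [(1 - eps) * exp(eps)]^(k/2) on Pr[sum X_i^2 <= k(1 - eps)]."""
    return ((1.0 - eps) * math.exp(eps)) ** (k / 2.0)


def min_dim(n: int, eps: float) -> int:
    """Smallest integer k with k > 24 / (3*eps^2 - 2*eps^3) * ln(n)."""
    return math.floor(24.0 / (3.0 * eps**2 - 2.0 * eps**3) * math.log(n)) + 1


for n, eps in [(100, 0.1), (1000, 0.3), (10**6, 0.5)]:
    k = min_dim(n, eps)
    print(n, eps, k, upper_tail_bound(eps, k), lower_tail_bound(eps, k), 1.0 / n**2)
```

In each row the lower-tail bound is the smaller of the two, which reflects the slack in the chain $e^{-k\epsilon^2/4} < e^{-2\log n}$ used for the second inequality.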

Proof of the JL Lemma

We now combine Lemma 3 with Boole's inequality (the union bound) to show that the JL Lemma holds.

From Lemma 3, for any point $p\in\mathbb{R}^d$, the following holds:

$$\Pr\left[\|f(p)\|^2 \notin \left[(1-\epsilon)\|p\|^2,\ (1+\epsilon)\|p\|^2\right]\right] \le \frac{2}{n^2}$$

Since this bound holds for every vector $p\in\mathbb{R}^d$, it holds in particular for each difference vector $u = o_i - o_j$, where $o_i, o_j \in O$.

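Furthermore, using the linearity of the projection, we have:

$$f(o_i - o_j) = f(o_i) - f(o_j)$$

Therefore, applying Lemma 3 to the difference between two points $o_i$ and $o_j$, we obtain:

$$\Pr\left[\|f(o_i) - f(o_j)\|^2 \notin \left[(1-\epsilon)\|o_i - o_j\|^2,\ (1+\epsilon)\|o_i - o_j\|^2\right]\right] \le \frac{2}{n^2}$$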

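The number of pairs of points $o_i, o_j \in O$ is $\binom{n}{2}$. Thus, using Boole's inequality, the probability that at least one pair of points falls outside the desired error bound is at most:

$$\frac{n(n-1)}{2}\cdot\frac{2}{n^2} = 1 - \frac{1}{n}$$

Thus, with probability at least $\frac{1}{n} > 0$, all pairs of points in $O$ fall within the desired error bounds. Since this probability is positive, at least one matrix $A$ yields a map $f$ with the required property, completing the proof of the JL Lemma.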
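To see the whole argument in action, one can project a small point set and inspect the ratio $\|f(o_i)-f(o_j)\|^2 / \|o_i-o_j\|^2$ over all pairs. The sketch below is a hedged illustration rather than a reference implementation: the random point set, the seed and the helper names are assumptions, and a single run only shows one draw of $A$, for which the proof guarantees success with probability at least $\frac{1}{n}$.

```python
import itertools
import math
import random


def project(A, p):
    """f(p) = A p / sqrt(k), where k is the number of rows of A."""
    scale = 1.0 / math.sqrt(len(A))
    return [scale * sum(a * x for a, x in zip(row, p)) for row in A]


def sq_dist(u, v):
    """Squared Euclidean distance between two vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def distortion_range(points, k, rng):
    """Return (min, max) of ||f(u) - f(v)||^2 / ||u - v||^2 over all pairs."""
    d = len(points[0])
    A = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(k)]
    images = [project(A, p) for p in points]
    ratios = [
        sq_dist(images[i], images[j]) / sq_dist(points[i], points[j])
        for i, j in itertools.combinations(range(len(points)), 2)
    ]
    return min(ratios), max(ratios)


def union_bound_failure(n):
    """Boole's inequality: C(n, 2) pairs, each failing with probability <= 2 / n^2."""
    return math.comb(n, 2) * 2.0 / n**2


rng = random.Random(0)
points = [[rng.uniform(-1.0, 1.0) for _ in range(8)] for _ in range(12)]
eps = 0.5
k = math.floor(24.0 / (3.0 * eps**2 - 2.0 * eps**3) * math.log(len(points))) + 1
low, high = distortion_range(points, k, rng)
print(k, low, high, union_bound_failure(len(points)))
```

If a particular draw of $A$ leaves some ratio outside $[1-\epsilon, 1+\epsilon]$, the usual remedy is simply to redraw $A$; the positive success probability makes this a valid randomized construction.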