Analytical derivations for non-exchangeable prior

Introduction

Assume \(\mathbf{y}\) is \(n \times 1\), \(\mathbf{x}\) is \(n \times p\), and \(m\) is the (maximum) number of causal effects.

Assume the following model, which has a non-exchangeable prior on sparse regression coefficients.

\[ p(\mathbf{y} \mid \mathbf{x}, \mathbf{w}) = N(\mathbf{y}; \mathbf{x} \mathbf{w}', v^{-1} \mathbf{I}) \] \[ w_j = \sum_{k=1}^m z_{kj} b_{kj} \]

\[ p(b_{kj} \mid z_{kj} = 1, v, v_b) = N(b_{kj}; 0, v^{-1} v_b^{-1}) \] \[ p(b_{kj} \mid z_{kj} = 0) = \delta(b_{kj}) \]

\[ p(z_k \mid \mathbf{p}) = \mathrm{Multinomial}(z_k; 1, \mathbf{p}) \]
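To make the generative process concrete, here is a minimal numpy sketch of sampling from this model (the function name sample_model, the uniform \(\mathbf{p}\), and all dimensions and hyperparameter values are illustrative assumptions, not part of the derivation):

import numpy as np

def sample_model(n=100, p=10, m=3, v=1., v_b=1., seed=0):
  rng = np.random.RandomState(seed)
  x = rng.randn(n, p)
  # Each effect k selects a single coordinate j with probability p_j (uniform here)
  z = rng.multinomial(1, np.ones(p) / p, size=m)
  # Slab values b_kj ~ N(0, (v v_b)^{-1}); entries with z_kj = 0 are masked below
  b = rng.normal(0., np.sqrt(1. / (v * v_b)), size=(m, p))
  w = (z * b).sum(axis=0)                               # w_j = sum_k z_kj b_kj
  y = x.dot(w) + rng.normal(0., np.sqrt(1. / v), size=n)  # residual precision v
  return y, x, w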

Assume the following variational approximation:

\[ q(b_{kj} \mid z_{kj} = 1, \mu, \phi) = N(b_{kj}; \mu_{kj}, \phi_{kj}^{-1}) \] \[ q(b_{kj} \mid z_{kj} = 0, \mu, \phi) = \delta(b_{kj}) \]

\[ q(z_k \mid \pi_k) = \mathrm{Multinomial}(z_k; 1, \boldsymbol{\pi}_k) \]

Our goal is to derive an analytical expression for the evidence lower bound.

\[ E_q[\ln p(\mathbf{y} \mid \mathbf{x}, \mathbf{w})] - KL(q(\mathbf{b} \mid \mathbf{z}) \Vert p(\mathbf{b} \mid \mathbf{z})) - KL(q(\mathbf{z}) \Vert p(\mathbf{z})) \]

Specifically, we seek an analytical expression for the first term:

\[ E_q \left[\frac{n}{2}\ln\left(\frac{v}{2\pi}\right) - \frac{v}{2} \sum_i \left(y_i - \sum_j x_{ij} w_j \right)^2 \right] \]
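Treating \(\mathbf{w}\) as a \(p \times 1\) column vector, the quadratic term depends on \(q\) only through the first moment \(E_q[\mathbf{w}]\) and the second moment matrix \(E_q[\mathbf{w} \mathbf{w}']\), which is what the rest of this note computes:

\[ E_q\left[\sum_i \left(y_i - \sum_j x_{ij} w_j\right)^2\right] = \mathbf{y}'\mathbf{y} - 2 \mathbf{y}' \mathbf{x}\, E_q[\mathbf{w}] + \mathrm{tr}\left(\mathbf{x}' \mathbf{x}\, E_q[\mathbf{w} \mathbf{w}']\right) \]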

Induction

Let \(w_{rj} = z_{rj} b_{rj}\) and let \(w_j^{(k)} = \sum_{r=1}^k z_{rj} b_{rj}\).

Suppose \(k = 1\). Then:

\[ E_q \left[ \sum_i \left( y_i^2 - 2 y_i \sum_j x_{ij} w_{1j} + \left(\sum_j x_{ij} w_{1j}\right)^2 \right) \right] \]

\[ = \sum_i \left( y_i^2 - 2 y_i \sum_j x_{ij} E_q[w_{1j}] + \sum_j x_{ij}^2 E_q[w_{1j}^2] + \sum_{j \neq k} x_{ij} x_{ik} E_q[w_{1j} w_{1k}] \right) \]

\[ = \sum_i \left( y_i^2 - 2 y_i \sum_j x_{ij} E_q[w_{1j}] + \sum_j x_{ij}^2 \left(E_q[w_{1j}]\right)^2 + \sum_j x_{ij}^2 V_q[w_{1j}] + \sum_{j \neq k} x_{ij} x_{ik} E_q[w_{1j} w_{1k}] \right) \]

But the last term is zero, because:

  • if \(z_{1j} = 0, w_{1j} = 0\)
  • if \(z_{1k} = 0, w_{1k} = 0\)
  • if \(z_{1j} = z_{1k} = 1\), then \(q(z_1) = 0\), since \(z_1\) is one-hot

Therefore, completing the square (the square introduces cross products of the means, which must be subtracted back off, since the corresponding expectations \(E_q[w_{1j} w_{1k}]\) vanish):

\[ E_q[\cdot] = \sum_i \left[ \left(y_i - \sum_j x_{ij} E_q[w_{1j}] \right)^2 + \sum_j x_{ij}^2 V_q[w_{1j}] - \sum_{j \neq k} x_{ij} x_{ik} E_q[w_{1j}] E_q[w_{1k}] \right] \]
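As a sanity check, here is a minimal numpy Monte Carlo sketch of this identity for a single effect (the variable names and parameter values are arbitrary):

import numpy as np

rng = np.random.RandomState(1)
n, p, num_samples = 5, 3, 200000
x, y = rng.randn(n, p), rng.randn(n)
pi = rng.dirichlet(np.ones(p))                 # q(z) selection probabilities
mu, phi = rng.randn(p), rng.gamma(2., size=p)  # q(b | z = 1) means and precisions

z = rng.multinomial(1, pi, size=num_samples)   # one-hot draws of z
b = mu + rng.randn(num_samples, p) / np.sqrt(phi)
w = z * b                                      # w_j = z_j b_j
mc = np.mean(np.sum((y - w.dot(x.T)) ** 2, axis=1))

ew = pi * mu                                   # E_q[w_j]
vw = pi * (1. / phi + mu ** 2) - ew ** 2       # V_q[w_j]
xe = x.dot(ew)
cross = np.sum(xe ** 2) - np.sum(x ** 2 * ew ** 2)  # sum_{j != k} x x E[w] E[w]
analytic = np.sum((y - xe) ** 2) + np.sum(x ** 2 * vw) - cross
print(mc - analytic)                           # near zero, up to Monte Carlo error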

Now suppose \(k > 1\). Then:

\[ E_q \left[ \sum_i \left( y_i - \sum_j x_{ij} w_{j}^{(k-1)} - \sum_j x_{ij} w_{kj}\right)^2 \right] \]

\[ = E_q \left[ \sum_i \left( y_i^2 - 2 y_i \sum_j x_{ij} w_{j}^{(k-1)} - 2 y_i \sum_j x_{ij} w_{kj} + \left(\sum_j x_{ij} w_j^{(k-1)}\right)^2 + \left(\sum_j x_{ij} w_{kj}\right)^2 + 2 \left(\sum_j x_{ij} w_j^{(k-1)}\right) \left(\sum_j x_{ij} w_{kj}\right) \right) \right] \]

\[ = \sum_i \left(y_i - \sum_j x_{ij} E_q[w_j^{(k)}]\right)^2 + \sum_i \sum_{r, s} x_{ir} x_{is} \left(E_q[w_r^{(k)} w_s^{(k)}] - E_q[w_r^{(k)}] E_q[w_s^{(k)}]\right) \]

Completing the square as in the base case, everything reduces to the second moments \(E_q[w_j w_k]\), which we compute next.

Induction on cross terms

Case L = 1

Here \(L\) denotes the number of single effects (\(m\) above). Dropping the effect index for clarity and considering two indices \(j, k \in [p]\), if \(j = k\):

\[ E_q[z_j^2 b_j^2] = \pi_j \left(V_q[b_j \mid z_j = 1] + (E_q[b_j \mid z_j = 1])^2 \right) \]

\[ = \pi_j (\phi_j^{-1} + \mu_j^2) \]

If \(j \neq k\), we have to condition \(b_j b_k\) on the 4 possible values of \((z_j, z_k)\).

But the term is 0 conditioned on \(z_j = 0, z_k = 0\), because \(q(b_j = 0 \mid z_j = 0) = 1\), and is 0 conditioned on \(z_j = 1, z_k = 1\), because that event has probability 0 under \(q\) (\(z\) is one-hot).

\[ E_q[z_j z_k b_j b_k] = \pi_j E_q[b_j b_k \mid z_j = 1] + \pi_k E_q[b_j b_k \mid z_k = 1] \]

But conditioned on \(z_j = 1\) we have \(z_k = 0\) and hence \(b_k = 0\) (and analogously conditioned on \(z_k = 1\)), so:

\[ E_q[w_j w_k] = 0 \]

Simulate to verify this result:

import numpy as np
import tensorflow as tf

def sample_ww(logits, mean, prec):
  # Draw z_k ~ Multinomial(1, softmax(logits_k)) for each effect (row)
  z = tf.multinomial(logits, 1)
  z = tf.reshape(tf.one_hot(z, logits.shape[-1]), tf.shape(mean))
  z = tf.cast(z, tf.float32)
  # Draw b_kj ~ N(mu_kj, phi_kj^{-1}); unselected entries are masked by z below
  b = mean + tf.random_normal(mean.shape) * tf.sqrt(tf.reciprocal(prec))
  # w_j = sum_k z_kj b_kj, as a 1 x p row vector
  w = tf.reduce_sum(z * b, axis=0, keep_dims=True)
  # One sample of the p x p matrix w'w
  ww = tf.matmul(w, w, transpose_a=True)
  return ww

def empirical_cov(l, p, num_samples=1000):
  logits = tf.get_variable('logits', initializer=tf.random_normal([l, p]))
  mean = tf.get_variable('mean', initializer=tf.random_normal([l, p]))
  prec = tf.get_variable('prec', initializer=tf.nn.softplus(tf.random_normal([l, p])))

  # Build the sampling op once, rather than adding new ops to the graph per draw
  ww_op = sample_ww(logits, mean, prec)
  ww = np.zeros((num_samples, p, p))
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(num_samples):
      ww[i] = sess.run(ww_op)
  # Monte Carlo estimate of the second moment matrix E_q[w'w]
  return ww.mean(axis=0)
with tf.Graph().as_default():
  cov = empirical_cov(l=1, p=2, num_samples=10000)
cov
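For \(l = 1\), the derivation above predicts that this estimate is approximately diagonal, with \(j\)th diagonal entry near \(\pi_j (\phi_j^{-1} + \mu_j^2)\) and off-diagonal entries near zero.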


Case L = 2

If \(j = k\):

\[ E_q[(z_{1j} b_{1j} + z_{2j} b_{2j})^2] = E_q[(z_{1j} b_{1j})^2] + E_q[(z_{2j} b_{2j})^2] + 2 E_q[z_{1j} z_{2j} b_{1j} b_{2j}] \]

From above, the first two terms are:

\[ E_q[(z_{1j} b_{1j})^2] = \pi_{1j} (\phi_{1j}^{-1} + \mu_{1j}^2) \]

\[ E_q[(z_{2j} b_{2j})^2] = \pi_{2j} (\phi_{2j}^{-1} + \mu_{2j}^2) \]

The final term is an expectation conditioned on the 4 possible values of \((z_{1j}, z_{2j})\). But it is only non-zero if \(z_{1j} = 1, z_{2j} = 1\), so

\[ 2 E_q[z_{1j} z_{2j} b_{1j} b_{2j}] = 2 \pi_{1j} \pi_{2j} \mu_{1j} \mu_{2j} \]

If \(j \neq k\):

\[ E_q[(z_{1j} b_{1j} + z_{2j} b_{2j}) (z_{1k} b_{1k} + z_{2k} b_{2k})] \]

From above, terms involving \(z_{1j} z_{1k}\) and \(z_{2j} z_{2k}\) vanish.

\[ = E_q[z_{1j} z_{2k} b_{1j} b_{2k}] + E_q[z_{1k} z_{2j} b_{1k} b_{2j}] \]

Each term is an expectation conditioned on the 4 possible values of \((z_{1j}, z_{2k})\), but is non-zero only conditioned on \(z_{1j} = 1, z_{2k} = 1\): if \(z_{1j} = 0\) then \(b_{1j} = 0\), and similarly if \(z_{2k} = 0\).

\[ E_q[z_{1j} z_{2k} b_{1j} b_{2k}] = \pi_{1j} \pi_{2k} \mu_{1j} \mu_{2k} \]

\[ E_q[z_{1k} z_{2j} b_{1k} b_{2j}] = \pi_{1k} \pi_{2j} \mu_{1k} \mu_{2j} \]
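The two cases above can be checked together with a short numpy Monte Carlo sketch (all names and values are illustrative):

import numpy as np

rng = np.random.RandomState(2)
p, num_samples = 4, 200000
pi = rng.dirichlet(np.ones(p), size=2)          # pi[r] = q(z_r) probabilities
mu, phi = rng.randn(2, p), rng.gamma(2., size=(2, p))

z1 = rng.multinomial(1, pi[0], size=num_samples)
z2 = rng.multinomial(1, pi[1], size=num_samples)
b = mu + rng.randn(num_samples, 2, p) / np.sqrt(phi)
w = z1 * b[:, 0] + z2 * b[:, 1]                 # w_j = z_1j b_1j + z_2j b_2j
mc = np.einsum('sj,sk->jk', w, w) / num_samples # Monte Carlo E_q[w_j w_k]

pm = pi * mu
# r != s cross terms: full outer product of column sums minus the r = s terms
analytic = np.outer(pm.sum(axis=0), pm.sum(axis=0)) - pm.T.dot(pm)
analytic[np.diag_indices(p)] += (pi * (1. / phi + mu ** 2)).sum(axis=0)
print(np.abs(mc - analytic).max())              # should be small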

General case

\[ E_q \left[\left( \sum_r z_{rj} b_{rj} \right) \left( \sum_s z_{sk} b_{sk} \right) \right] \]

We sum terms over pairs \(r, s\). This can be conceptualized as constructing an \(m \times m\) matrix for each pair \(j, k\), where entry \(r, s\) contains the corresponding term, and summing its entries. Therefore, we can simply reuse the results for case \(L = 2\):

If \(j = k, r = s\):

\[ \pi_{rj} (\phi_{rj}^{-1} + \mu_{rj}^2) \]

If \(j = k, r \neq s\):

\[ 2 \pi_{rj} \pi_{sj} \mu_{rj} \mu_{sj} \]

If \(j \neq k, r = s\):

\[ 0 \]

If \(j \neq k, r \neq s\):

\[ \pi_{rj} \pi_{sk} \mu_{rj} \mu_{sk} + \pi_{rk} \pi_{sj} \mu_{rk} \mu_{sj} \]

To implement this efficiently, notice that the \(r \neq s\) contribution has the same form for every pair \(j, k\). On the diagonal, the factor of two is recovered simply by summing the terms corresponding to \((r, s)\) and \((s, r)\). The diagonal additionally carries the extra terms \(\pi_{rj} (\phi_{rj}^{-1} + \mu_{rj}^2)\).
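Collecting the four cases, the full second moment matrix \(E_q[w_j w_k]\) (summing over all effect pairs \(r, s\)) can be written compactly, with \(\delta_{jk}\) the Kronecker delta:

\[ E_q[w_j w_k] = \delta_{jk} \sum_r \pi_{rj} \left(\phi_{rj}^{-1} + \mu_{rj}^2\right) + \left(\sum_r \pi_{rj} \mu_{rj}\right) \left(\sum_s \pi_{sk} \mu_{sk}\right) - \sum_r \pi_{rj} \mu_{rj} \pi_{rk} \mu_{rk} \]

The last two terms implement the \(r \neq s\) sum as an outer product minus the \(r = s\) contributions; for a single effect they cancel exactly.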

def analytical_cov(logits, mean, prec):
  # tf.multinomial treats logits as unnormalized log probabilities, so use softmax
  probs = tf.nn.softmax(logits)
  # Diagonal terms: sum_r pi_rj (phi_rj^{-1} + mu_rj^2)
  var = tf.reduce_sum(probs * (tf.square(mean) + tf.reciprocal(prec)), axis=0)
  cov = tf.diag(var)
  # Cross terms over r != s, for every pair (j, k): full outer product of
  # sum_r pi_rj mu_rj minus the r = s contributions. For a single effect
  # (L = 1) this correction is identically zero, so no branch is needed.
  pm = probs * mean
  total = tf.reduce_sum(pm, axis=0, keep_dims=True)
  cov += tf.matmul(total, total, transpose_a=True) - tf.matmul(pm, pm, transpose_a=True)
  return cov

def compare_cov(l, p, num_samples=50):
  logits = tf.get_variable('logits', initializer=tf.random_normal([l, p]))
  mean = tf.get_variable('mean', initializer=tf.random_normal([l, p]))
  prec = tf.get_variable('prec', initializer=tf.nn.softplus(tf.random_normal([l, p])))

  # Build both ops once; adding new ops inside the sampling loop would grow the graph
  cov_op = analytical_cov(logits, mean, prec)
  ww_op = sample_ww(logits, mean, prec)
  ww = np.zeros((num_samples, p, p))
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    cov = sess.run(cov_op)
    for i in range(num_samples):
      ww[i] = sess.run(ww_op)
  return {'true': cov, 'mean': ww.mean(axis=0), 'std': ww.std(axis=0)}
with tf.Graph().as_default():
  result = compare_cov(l=1, p=5, num_samples=500)
result
{'mean': array([[ 0.51065322,  0.        ,  0.        ,  0.        ,  0.        ],
          [ 0.        ,  0.46582288,  0.        ,  0.        ,  0.        ],
          [ 0.        ,  0.        ,  0.05170335,  0.        ,  0.        ],
          [ 0.        ,  0.        ,  0.        ,  0.49491202,  0.        ],
          [ 0.        ,  0.        ,  0.        ,  0.        ,  1.00800762]]),
   'std': array([[ 2.28961071,  0.        ,  0.        ,  0.        ,  0.        ],
          [ 0.        ,  1.94926498,  0.        ,  0.        ,  0.        ],
          [ 0.        ,  0.        ,  0.35635244,  0.        ,  0.        ],
          [ 0.        ,  0.        ,  0.        ,  1.93968206,  0.        ],
          [ 0.        ,  0.        ,  0.        ,  0.        ,  2.17985675]]),
   'true': array([[ 1.18962514,  0.        ,  0.        ,  0.        ,  0.        ],
          [ 0.        ,  0.4958204 ,  0.        ,  0.        ,  0.        ],
          [ 0.        ,  0.        ,  0.08744346,  0.        ,  0.        ],
          [ 0.        ,  0.        ,  0.        ,  0.15774874,  0.        ],
          [ 0.        ,  0.        ,  0.        ,  0.        ,  0.33383933]], dtype=float32)}
with tf.Graph().as_default():
  result = compare_cov(l=2, p=5, num_samples=1000)
result
{'mean': array([[  1.56900097e+00,  -3.64141234e-01,   1.65107749e-02,
            -4.45378973e-02,  -2.15203249e-01],
          [ -3.64141234e-01,   1.94331921e+00,  -2.05560835e-02,
            -8.16864129e-02,  -3.33085588e-02],
          [  1.65107749e-02,  -2.05560835e-02,   1.89392084e-01,
            -1.24557875e-06,  -1.59193349e-02],
          [ -4.45378973e-02,  -8.16864129e-02,  -1.24557875e-06,
             3.59574653e-01,  -1.07985522e-02],
          [ -2.15203249e-01,  -3.33085588e-02,  -1.59193349e-02,
            -1.07985522e-02,   9.24471020e-01]]),
   'std': array([[ 3.08918563,  1.38056227,  0.36943229,  0.59176396,  0.97111137],
          [ 1.38056227,  3.28394083,  0.51141142,  0.63430121,  0.64153024],
          [ 0.36943229,  0.51141142,  0.75163003,  0.12413185,  0.36090543],
          [ 0.59176396,  0.63430121,  0.12413185,  1.52855069,  0.34238875],
          [ 0.97111137,  0.64153024,  0.36090543,  0.34238875,  2.80481511]]),
   'true': array([[ 2.27550507, -0.21468884,  0.46644443,  0.20610106, -2.23284817],
          [-0.21468884,  4.57976007, -0.15748027,  0.97013932,  2.04361153],
          [ 0.46644443, -0.15748027,  0.17025512,  0.04313061, -0.54789758],
          [ 0.20610106,  0.97013932,  0.04313061,  0.39109433,  0.29355159],
          [-2.23284817,  2.04361153, -0.54789758,  0.29355159,  4.82095623]], dtype=float32)}

Author: Abhishek Sarkar

Created: 2018-03-18 Sun 19:50
