Source code for dibs.inference.svgd

import functools
import numpy as onp

import jax
import jax.numpy as jnp
from jax import jit, vmap, random, grad
from jax.tree_util import tree_map
from jax.scipy.special import logsumexp
from jax.example_libraries import optimizers

from dibs.inference.dibs import DiBS
from dibs.kernel import AdditiveFrobeniusSEKernel, JointAdditiveFrobeniusSEKernel
from dibs.metrics import ParticleDistribution
from dibs.utils.func import expand_by


class MarginalDiBS(DiBS):
    """
    This class implements Stein Variational Gradient Descent (SVGD) (Liu and Wang, 2016)
    for DiBS inference (Lorch et al., 2021) of the marginal DAG posterior :math:`p(G | D)`.
    For joint inference of :math:`p(G, \\Theta | D)`, use the analogous class
    :class:`~dibs.inference.JointDiBS`.

    An SVGD update of tensor :math:`v` is defined as

    :math:`\\phi(v) \\propto \\sum_{u} k(v, u) \\nabla_u \\log p(u) + \\nabla_u k(u, v)`

    Args:
        x (ndarray): observations of shape ``[n_observations, n_vars]``
        interv_mask (ndarray, optional): binary matrix of shape ``[n_observations, n_vars]`` indicating
            whether a given variable was intervened upon in a given sample (intervention = 1, no intervention = 0)
        graph_model: Model defining the prior :math:`\\log p(G)` underlying the inferred posterior.
            Object *has to implement one method*: ``unnormalized_log_prob_soft``.
            Example: :class:`~dibs.models.ErdosReniDAGDistribution`
        likelihood_model: Model defining the marginal likelihood :math:`\\log p(D | G)` underlying the
            inferred posterior. Object *has to implement one method*: ``interventional_log_marginal_prob``.
            Example: :class:`~dibs.models.BGe`
        kernel: Class of kernel. *Has to implement the method* ``eval(u, v)``.
            Example: :class:`~dibs.kernel.AdditiveFrobeniusSEKernel`
        kernel_param (dict): kwargs to instantiate ``kernel``
        optimizer (str): optimizer identifier
        optimizer_param (dict): kwargs to instantiate ``optimizer``
        alpha_linear (float): slope of linear schedule for inverse temperature :math:`\\alpha`
            of sigmoid in latent graph model :math:`p(G | Z)`
        beta_linear (float): slope of linear schedule for inverse temperature :math:`\\beta`
            of constraint penalty in latent prior :math:`p(Z)`
        tau (float): constant Gumbel-softmax temperature parameter
        n_grad_mc_samples (int): number of Monte Carlo samples in gradient estimator
            for likelihood term :math:`p(\\Theta, D | G)`
        n_acyclicity_mc_samples (int): number of Monte Carlo samples in gradient estimator
            for acyclicity constraint
        grad_estimator_z (str): gradient estimator :math:`\\nabla_Z` of expectation over :math:`p(G | Z)`;
            choices: ``score`` or ``reparam``
        score_function_baseline (float): scale of additive baseline in score function (REINFORCE) estimator;
            ``score_function_baseline == 0.0`` corresponds to not using a baseline
        latent_prior_std (float): standard deviation of Gaussian prior over :math:`Z`; defaults to ``1/sqrt(k)``
    """

    def __init__(self, *,
                 x,
                 graph_model,
                 likelihood_model,
                 interv_mask=None,
                 kernel=AdditiveFrobeniusSEKernel,
                 kernel_param=None,
                 optimizer="rmsprop",
                 optimizer_param=None,
                 alpha_linear=1.0,
                 beta_linear=1.0,
                 tau=1.0,
                 n_grad_mc_samples=128,
                 n_acyclicity_mc_samples=32,
                 grad_estimator_z="score",
                 score_function_baseline=0.0,
                 latent_prior_std=None,
                 verbose=False):

        # handle mutable default args
        if kernel_param is None:
            kernel_param = {"h": 5.0}
        if optimizer_param is None:
            optimizer_param = {"stepsize": 0.005}

        # handle interv mask in observational case
        if interv_mask is None:
            interv_mask = jnp.zeros_like(x, dtype=jnp.int32)

        # init DiBS superclass methods
        super().__init__(
            x=x,
            interv_mask=interv_mask,
            log_graph_prior=graph_model.unnormalized_log_prob_soft,
            log_joint_prob=likelihood_model.interventional_log_marginal_prob,
            alpha_linear=alpha_linear,
            beta_linear=beta_linear,
            tau=tau,
            n_grad_mc_samples=n_grad_mc_samples,
            n_acyclicity_mc_samples=n_acyclicity_mc_samples,
            grad_estimator_z=grad_estimator_z,
            score_function_baseline=score_function_baseline,
            latent_prior_std=latent_prior_std,
            verbose=verbose,
        )

        self.likelihood_model = likelihood_model
        self.graph_model = graph_model

        # functions for post-hoc likelihood evaluations
        self.eltwise_log_marginal_likelihood_observ = vmap(lambda g, x_ho:
            likelihood_model.interventional_log_marginal_prob(g, None, x_ho, jnp.zeros_like(x_ho), None), (0, None), 0)
        self.eltwise_log_marginal_likelihood_interv = vmap(lambda g, x_ho, interv_msk_ho:
            likelihood_model.interventional_log_marginal_prob(g, None, x_ho, interv_msk_ho, None), (0, None, None), 0)

        self.kernel = kernel(**kernel_param)

        if optimizer == 'gd':
            self.opt = optimizers.sgd(optimizer_param['stepsize'])
        elif optimizer == 'rmsprop':
            self.opt = optimizers.rmsprop(optimizer_param['stepsize'])
        else:
            raise ValueError(f"Unknown optimizer `{optimizer}`")

    def _sample_initial_random_particles(self, *, key, n_particles, n_dim=None):
        """
        Samples random particles to initialize SVGD

        Args:
            key (ndarray): rng key
            n_particles (int): number of particles inferred
            n_dim (int): size of latent dimension :math:`k`. Defaults to ``n_vars``, s.t. :math:`k = d`

        Returns:
            batch of latent tensors ``[n_particles, d, k, 2]``
        """
        # default: full rank
        if n_dim is None:
            n_dim = self.n_vars

        # std as in Gaussian prior over Z
        std = self.latent_prior_std or (1.0 / jnp.sqrt(n_dim))

        # sample
        key, subk = random.split(key)
        z = random.normal(subk, shape=(n_particles, self.n_vars, n_dim, 2)) * std

        return z

    def _f_kernel(self, x_latent, y_latent):
        """
        Evaluates kernel

        Args:
            x_latent (ndarray): latent tensor of shape ``[d, k, 2]``
            y_latent (ndarray): latent tensor of shape ``[d, k, 2]``

        Returns:
            kernel value of shape ``[1, ]``
        """
        return self.kernel.eval(x=x_latent, y=y_latent)

    def _f_kernel_mat(self, x_latents, y_latents):
        """
        Computes pairwise kernel matrix

        Args:
            x_latents (ndarray): latent tensor of shape ``[A, d, k, 2]``
            y_latents (ndarray): latent tensor of shape ``[B, d, k, 2]``

        Returns:
            kernel values of shape ``[A, B]``
        """
        return vmap(vmap(self._f_kernel, (None, 0), 0), (0, None), 0)(x_latents, y_latents)

    def _eltwise_grad_kernel_z(self, x_latents, y_latent):
        """
        Computes gradient :math:`\\nabla_Z k(Z, Z')` elementwise
        for each provided particle :math:`Z` in batch ``x_latents``

        Args:
            x_latents (ndarray): batch of latent particles for :math:`Z` of shape ``[n_particles, d, k, 2]``
            y_latent (ndarray): single latent particle :math:`Z'` ``[d, k, 2]``

        Returns:
            batch of gradients of shape ``[n_particles, d, k, 2]``
        """
        grad_kernel_z = grad(self._f_kernel, 0)
        return vmap(grad_kernel_z, (0, None), 0)(x_latents, y_latent)

    def _z_update(self, single_z, kxx, z, grad_log_prob_z):
        """
        Computes SVGD update for ``single_z`` particle given the kernel values ``kxx`` and
        the :math:`d/dZ` gradients of the target density for each of the available particles

        Args:
            single_z (ndarray): single latent tensor ``[d, k, 2]``, which is the Z particle being updated
            kxx (ndarray): pairwise kernel values for all particles ``[n_particles, n_particles]``
            z (ndarray): all latent particles ``[n_particles, d, k, 2]``
            grad_log_prob_z (ndarray): gradients of all Z particles w.r.t. target density
                of shape ``[n_particles, d, k, 2]``

        Returns:
            transform vector of shape ``[d, k, 2]`` for the particle ``single_z``
        """
        # compute terms in sum
        weighted_gradient_ascent = kxx[..., None, None, None] * grad_log_prob_z
        repulsion = self._eltwise_grad_kernel_z(z, single_z)

        # average and negate (for optimizer)
        return - (weighted_gradient_ascent + repulsion).mean(axis=0)

    def _parallel_update_z(self, *args):
        """
        Vectorizes :func:`~dibs.inference.MarginalDiBS._z_update`
        for all available particles in batched first input dim (``single_z``).
        Otherwise, same inputs as :func:`~dibs.inference.MarginalDiBS._z_update`.
        """
        return vmap(self._z_update, (0, 1, None, None), 0)(*args)

    def _svgd_step(self, t, opt_state_z, key, sf_baseline):
        """
        Performs a single SVGD step in the DiBS framework, updating all :math:`Z` particles jointly.

        Args:
            t (int): step
            opt_state_z: optimizer state for latent :math:`Z` particles; contains ``[n_particles, d, k, 2]``
            key (ndarray): prng key
            sf_baseline (ndarray): batch of baseline values of shape ``[n_particles, ]``
                in case score function gradient is used

        Returns:
            the updated inputs ``opt_state_z``, ``key``, ``sf_baseline``
        """
        z = self.get_params(opt_state_z)  # [n_particles, d, k, 2]
        n_particles = z.shape[0]

        # d/dz log p(D | z)
        key, *batch_subk = random.split(key, n_particles + 1)
        dz_log_likelihood, sf_baseline = self.eltwise_grad_z_likelihood(z, None, sf_baseline, t, jnp.array(batch_subk))
        # here `None` is a placeholder for theta (in the joint inference case)
        # since this is an inherited function from the general `DiBS` class

        # d/dz log p(z) (acyclicity)
        key, *batch_subk = random.split(key, n_particles + 1)
        dz_log_prior = self.eltwise_grad_latent_prior(z, jnp.array(batch_subk), t)

        # d/dz log p(z, D) = d/dz log p(z) + log p(D | z)
        dz_log_prob = dz_log_prior + dz_log_likelihood

        # k(z, z) for all particles
        kxx = self._f_kernel_mat(z, z)

        # transformation phi() applied in batch to each particle individually
        phi_z = self._parallel_update_z(z, kxx, z, dz_log_prob)

        # apply transformation
        # `x += stepsize * phi`; the phi returned is negated for SVGD
        opt_state_z = self.opt_update(t, phi_z, opt_state_z)

        return opt_state_z, key, sf_baseline

    # the @jit here is crucial: it compiles the inner loop over SVGD steps
    @functools.partial(jit, static_argnums=(0, 2))
    def _svgd_loop(self, start, n_steps, init):
        return jax.lax.fori_loop(start, start + n_steps, lambda i, args: self._svgd_step(i, *args), init)

    def sample(self, *, key, n_particles, steps, n_dim_particles=None, callback=None, callback_every=None):
        """
        Use SVGD with DiBS to sample ``n_particles`` particles :math:`G` from the marginal posterior
        :math:`p(G | D)` as defined by the graph prior ``self.graph_model`` and the marginal
        likelihood ``self.likelihood_model``

        Arguments:
            key (ndarray): prng key
            n_particles (int): number of particles to sample
            steps (int): number of SVGD steps performed
            n_dim_particles (int): latent dimensionality :math:`k` of particles
                :math:`Z = \\{ U, V \\}` with :math:`U, V \\in \\mathbb{R}^{k \\times d}`. Default is ``n_vars``
            callback: function to be called every ``callback_every`` steps of SVGD.
            callback_every: if ``None``, ``callback`` is only called after particle updates have finished

        Returns:
            batch of samples :math:`G \\sim p(G | D)` of shape ``[n_particles, n_vars, n_vars]``
        """

        # randomly sample initial particles
        key, subk = random.split(key)
        init_z = self._sample_initial_random_particles(key=subk, n_particles=n_particles, n_dim=n_dim_particles)

        # initialize score function baseline (one for each particle)
        n_particles, _, n_dim, _ = init_z.shape
        sf_baseline = jnp.zeros(n_particles)

        if self.latent_prior_std is None:
            self.latent_prior_std = 1.0 / jnp.sqrt(n_dim)

        # maintain updated particles with optimizer state
        opt_init, self.opt_update, get_params = self.opt
        self.get_params = jit(get_params)
        opt_state_z = opt_init(init_z)

        # execute particle update steps for all particles in parallel using `vmap` functions;
        # faster if the for-loop is functionally pure and compiled, so only interrupt for callback
        callback_every = callback_every or steps
        for t in (range(0, steps, callback_every) if steps else range(0)):

            # perform sequence of SVGD steps
            opt_state_z, key, sf_baseline = self._svgd_loop(t, callback_every, (opt_state_z, key, sf_baseline))

            # callback
            if callback:
                z = self.get_params(opt_state_z)
                callback(
                    dibs=self,
                    t=t + callback_every,
                    zs=z,
                )

        # retrieve transported particles
        z_final = jax.device_get(self.get_params(opt_state_z))

        # as alpha is large, we can convert the latents Z to their corresponding graphs G
        g_final = self.particle_to_g_lim(z_final)

        return g_final

    def get_empirical(self, g):
        """
        Converts batch of binary (adjacency) matrices into *empirical* particle distribution,
        where mixture weights correspond to counts/occurrences

        Args:
            g (ndarray): batch of graph samples ``[n_particles, d, d]`` with binary values

        Returns:
            :class:`~dibs.metrics.ParticleDistribution`:
            particle distribution of graph samples and associated log probabilities
        """
        N, _, _ = g.shape
        unique, counts = onp.unique(g, axis=0, return_counts=True)

        # empirical distribution using counts
        logp = jnp.log(counts) - jnp.log(N)

        return ParticleDistribution(logp=logp, g=unique)

    def get_mixture(self, g):
        """
        Converts batch of binary (adjacency) matrices into *mixture* particle distribution,
        where mixture weights correspond to unnormalized target (i.e. posterior) probabilities

        Args:
            g (ndarray): batch of graph samples ``[n_particles, d, d]`` with binary values

        Returns:
            :class:`~dibs.metrics.ParticleDistribution`:
            particle distribution of graph samples and associated log probabilities
        """
        N, _, _ = g.shape

        # mixture weighted by respective marginal probabilities
        eltwise_log_marginal_target = vmap(lambda single_g:
            self.log_joint_prob(single_g, None, self.x, self.interv_mask, None), 0, 0)
        logp = eltwise_log_marginal_target(g)
        logp -= logsumexp(logp)

        return ParticleDistribution(logp=logp, g=g)
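
# Illustrative usage sketch (not part of the library API): how MarginalDiBS is
# typically driven end-to-end, from model construction to an empirical particle
# distribution over sampled graphs. The ErdosReniDAGDistribution and BGe
# constructors below are assumptions based on the examples named in the class
# docstring above; check `dibs.models` for their exact signatures.
def _example_marginal_dibs_usage(key, x):
    """Runs marginal DiBS-SVGD on observations ``x`` of shape [n_observations, n_vars]."""
    from dibs.models import ErdosReniDAGDistribution, BGe  # import path as referenced in the docstring

    n_vars = x.shape[1]
    graph_model = ErdosReniDAGDistribution(n_vars)   # prior p(G)          (assumed signature)
    likelihood_model = BGe(n_vars=n_vars)            # marginal likelihood p(D | G)  (assumed signature)

    marginal_dibs = MarginalDiBS(x=x, graph_model=graph_model, likelihood_model=likelihood_model)

    # transport `n_particles` latent particles Z for `steps` SVGD steps
    # and convert them to hard graphs G
    key, subk = random.split(key)
    gs = marginal_dibs.sample(key=subk, n_particles=20, steps=1000)

    # weight sampled graphs by counts; use `get_mixture` for posterior-weighted particles instead
    return marginal_dibs.get_empirical(gs)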

class JointDiBS(DiBS):
    """
    This class implements Stein Variational Gradient Descent (SVGD) (Liu and Wang, 2016)
    for DiBS inference (Lorch et al., 2021) of the joint DAG-parameter posterior :math:`p(G, \\Theta | D)`.
    For marginal inference of :math:`p(G | D)`, use the analogous class
    :class:`~dibs.inference.MarginalDiBS`.

    An SVGD update of tensor :math:`v` is defined as

    :math:`\\phi(v) \\propto \\sum_{u} k(v, u) \\nabla_u \\log p(u) + \\nabla_u k(u, v)`

    Args:
        x (ndarray): observations of shape ``[n_observations, n_vars]``
        interv_mask (ndarray, optional): binary matrix of shape ``[n_observations, n_vars]`` indicating
            whether a given variable was intervened upon in a given sample (intervention = 1, no intervention = 0)
        graph_model: Model defining the prior :math:`\\log p(G)` underlying the inferred posterior.
            Object *has to implement one method*: ``unnormalized_log_prob_soft``.
            Example: :class:`~dibs.models.ErdosReniDAGDistribution`
        likelihood_model: Model defining the joint likelihood
            :math:`\\log p(\\Theta, D | G) = \\log p(\\Theta | G) + \\log p(D | G, \\Theta)` underlying the
            inferred posterior. Object *has to implement one method*: ``interventional_log_joint_prob``.
            Example: :class:`~dibs.models.LinearGaussian`
        kernel: Class of kernel. *Has to implement the method* ``eval(u, v)``.
            Example: :class:`~dibs.kernel.JointAdditiveFrobeniusSEKernel`
        kernel_param (dict): kwargs to instantiate ``kernel``
        optimizer (str): optimizer identifier
        optimizer_param (dict): kwargs to instantiate ``optimizer``
        alpha_linear (float): slope of linear schedule for inverse temperature :math:`\\alpha`
            of sigmoid in latent graph model :math:`p(G | Z)`
        beta_linear (float): slope of linear schedule for inverse temperature :math:`\\beta`
            of constraint penalty in latent prior :math:`p(Z)`
        tau (float): constant Gumbel-softmax temperature parameter
        n_grad_mc_samples (int): number of Monte Carlo samples in gradient estimator
            for likelihood term :math:`p(\\Theta, D | G)`
        n_acyclicity_mc_samples (int): number of Monte Carlo samples in gradient estimator
            for acyclicity constraint
        grad_estimator_z (str): gradient estimator :math:`\\nabla_Z` of expectation over :math:`p(G | Z)`;
            choices: ``score`` or ``reparam``
        score_function_baseline (float): scale of additive baseline in score function (REINFORCE) estimator;
            ``score_function_baseline == 0.0`` corresponds to not using a baseline
        latent_prior_std (float): standard deviation of Gaussian prior over :math:`Z`; defaults to ``1/sqrt(k)``
    """

    def __init__(self, *,
                 x,
                 graph_model,
                 likelihood_model,
                 interv_mask=None,
                 kernel=JointAdditiveFrobeniusSEKernel,
                 kernel_param=None,
                 optimizer="rmsprop",
                 optimizer_param=None,
                 alpha_linear=0.05,
                 beta_linear=1.0,
                 tau=1.0,
                 n_grad_mc_samples=128,
                 n_acyclicity_mc_samples=32,
                 grad_estimator_z="reparam",
                 score_function_baseline=0.0,
                 latent_prior_std=None,
                 verbose=False):

        # handle mutable default args
        if kernel_param is None:
            kernel_param = {"h_latent": 5.0, "h_theta": 500.0}
        if optimizer_param is None:
            optimizer_param = {"stepsize": 0.005}

        # handle interv mask in observational case
        if interv_mask is None:
            interv_mask = jnp.zeros_like(x, dtype=jnp.int32)

        # init DiBS superclass methods
        super().__init__(
            x=x,
            interv_mask=interv_mask,
            log_graph_prior=graph_model.unnormalized_log_prob_soft,
            log_joint_prob=likelihood_model.interventional_log_joint_prob,
            alpha_linear=alpha_linear,
            beta_linear=beta_linear,
            tau=tau,
            n_grad_mc_samples=n_grad_mc_samples,
            n_acyclicity_mc_samples=n_acyclicity_mc_samples,
            grad_estimator_z=grad_estimator_z,
            score_function_baseline=score_function_baseline,
            latent_prior_std=latent_prior_std,
            verbose=verbose,
        )

        self.likelihood_model = likelihood_model
        self.graph_model = graph_model

        # functions for post-hoc likelihood evaluations
        self.eltwise_log_likelihood_observ = vmap(lambda g, theta, x_ho:
            likelihood_model.interventional_log_joint_prob(g, theta, x_ho, jnp.zeros_like(x_ho), None), (0, 0, None), 0)
        self.eltwise_log_likelihood_interv = vmap(lambda g, theta, x_ho, interv_msk_ho:
            likelihood_model.interventional_log_joint_prob(g, theta, x_ho, interv_msk_ho, None), (0, 0, None, None), 0)

        self.kernel = kernel(**kernel_param)

        if optimizer == 'gd':
            self.opt = optimizers.sgd(optimizer_param['stepsize'])
        elif optimizer == 'rmsprop':
            self.opt = optimizers.rmsprop(optimizer_param['stepsize'])
        else:
            raise ValueError(f"Unknown optimizer `{optimizer}`")

    def _sample_initial_random_particles(self, *, key, n_particles, n_dim=None):
        """
        Samples random particles to initialize SVGD

        Args:
            key (ndarray): rng key
            n_particles (int): number of particles inferred
            n_dim (int): size of latent dimension :math:`k`. Defaults to ``n_vars``, s.t. :math:`k = d`

        Returns:
            batch of latent tensors ``[n_particles, d, k, 2]``
        """
        # default: full rank
        if n_dim is None:
            n_dim = self.n_vars

        # std as in Gaussian prior over Z
        std = self.latent_prior_std or (1.0 / jnp.sqrt(n_dim))

        # sample latent particles
        key, subk = random.split(key)
        z = random.normal(subk, shape=(n_particles, self.n_vars, n_dim, 2)) * std

        # sample from parameter prior
        key, subk = random.split(key)
        theta = self.likelihood_model.sample_parameters(key=subk, n_particles=n_particles, n_vars=self.n_vars)

        return z, theta

    def _f_kernel(self, x_latent, x_theta, y_latent, y_theta):
        """
        Evaluates kernel

        Args:
            x_latent (ndarray): latent tensor of shape ``[d, k, 2]``
            x_theta (Any): parameter PyTree
            y_latent (ndarray): latent tensor of shape ``[d, k, 2]``
            y_theta (Any): parameter PyTree

        Returns:
            kernel value of shape ``[1, ]``
        """
        return self.kernel.eval(
            x_latent=x_latent, x_theta=x_theta,
            y_latent=y_latent, y_theta=y_theta)

    def _f_kernel_mat(self, x_latents, x_thetas, y_latents, y_thetas):
        """
        Computes pairwise kernel matrix

        Args:
            x_latents (ndarray): latent tensor of shape ``[A, d, k, 2]``
            x_thetas (Any): parameter PyTree with batch size ``A`` as leading dim
            y_latents (ndarray): latent tensor of shape ``[B, d, k, 2]``
            y_thetas (Any): parameter PyTree with batch size ``B`` as leading dim

        Returns:
            kernel values of shape ``[A, B]``
        """
        return vmap(vmap(self._f_kernel, (None, None, 0, 0), 0),
                    (0, 0, None, None), 0)(x_latents, x_thetas, y_latents, y_thetas)

    def _eltwise_grad_kernel_z(self, x_latents, x_thetas, y_latent, y_theta):
        """
        Computes gradient :math:`\\nabla_Z k((Z, \\Theta), (Z', \\Theta'))` elementwise
        for each provided particle :math:`(Z, \\Theta)` in batch (``x_latents``, ``x_thetas``)

        Args:
            x_latents (ndarray): batch of latent particles for :math:`Z` of shape ``[n_particles, d, k, 2]``
            x_thetas (Any): batch of parameter PyTrees for :math:`\\Theta` with leading dim ``n_particles``
            y_latent (ndarray): single latent particle :math:`Z'` ``[d, k, 2]``
            y_theta (Any): single parameter PyTree for :math:`\\Theta'`

        Returns:
            batch of gradients of shape ``[n_particles, d, k, 2]``
        """
        grad_kernel_z = grad(self._f_kernel, 0)
        return vmap(grad_kernel_z, (0, 0, None, None), 0)(x_latents, x_thetas, y_latent, y_theta)

    def _eltwise_grad_kernel_theta(self, x_latents, x_thetas, y_latent, y_theta):
        """
        Computes gradient :math:`\\nabla_{\\Theta} k((Z, \\Theta), (Z', \\Theta'))` elementwise
        for each provided particle :math:`(Z, \\Theta)` in batch (``x_latents``, ``x_thetas``)

        Args:
            x_latents (ndarray): batch of latent particles for :math:`Z` of shape ``[n_particles, d, k, 2]``
            x_thetas (Any): batch of parameter PyTrees for :math:`\\Theta` with leading dim ``n_particles``
            y_latent (ndarray): single latent particle :math:`Z'` ``[d, k, 2]``
            y_theta (Any): single parameter PyTree for :math:`\\Theta'`

        Returns:
            batch of gradient PyTrees with leading dim ``n_particles``
        """
        grad_kernel_theta = grad(self._f_kernel, 1)
        return vmap(grad_kernel_theta, (0, 0, None, None), 0)(x_latents, x_thetas, y_latent, y_theta)

    def _z_update(self, single_z, single_theta, kxx, z, theta, grad_log_prob_z):
        """
        Computes SVGD update for the ``single_z`` component of a particle tuple (``single_z``, ``single_theta``)
        given the kernel values ``kxx`` and the :math:`d/dZ` gradients of the target density
        for each of the available particles

        Args:
            single_z (ndarray): single latent tensor ``[d, k, 2]``, which is the :math:`Z` particle being updated
            single_theta (Any): single parameter PyTree, the :math:`\\Theta` particle belonging to
                the :math:`Z` particle being updated
            kxx (ndarray): pairwise kernel values for all particles, of shape ``[n_particles, n_particles]``
            z (ndarray): all latent particles ``[n_particles, d, k, 2]``
            theta (Any): all theta particles as PyTree with leading dim ``n_particles``
            grad_log_prob_z (ndarray): gradients of all Z particles w.r.t. target density
                of shape ``[n_particles, d, k, 2]``

        Returns:
            transform vector of shape ``[d, k, 2]`` for the particle ``single_z``
        """
        # compute terms in sum
        weighted_gradient_ascent = kxx[..., None, None, None] * grad_log_prob_z
        repulsion = self._eltwise_grad_kernel_z(z, theta, single_z, single_theta)

        # average and negate (for optimizer)
        return - (weighted_gradient_ascent + repulsion).mean(axis=0)

    def _parallel_update_z(self, *args):
        """
        Vectorizes :func:`~dibs.inference.JointDiBS._z_update`
        for all available particles in batched first and second input dim (``single_z``, ``single_theta``).
        Otherwise, same inputs as :func:`~dibs.inference.JointDiBS._z_update`.
        """
        return vmap(self._z_update, (0, 0, 1, None, None, None), 0)(*args)

    def _theta_update(self, single_z, single_theta, kxx, z, theta, grad_log_prob_theta):
        """
        Computes SVGD update for the ``single_theta`` component of a particle tuple (``single_z``, ``single_theta``)
        given the kernel values ``kxx`` and the :math:`d/d\\Theta` gradients of the target density
        for each of the available particles.
        Analogous to :func:`~dibs.inference.JointDiBS._z_update` but for updating :math:`\\Theta`.

        Args:
            single_z (ndarray): single latent tensor ``[d, k, 2]``, the :math:`Z` particle belonging to
                the :math:`\\Theta` particle being updated
            single_theta (Any): single parameter PyTree, which is the :math:`\\Theta` particle being updated
            kxx (ndarray): pairwise kernel values for all particles, of shape ``[n_particles, n_particles]``
            z (ndarray): all latent particles ``[n_particles, d, k, 2]``
            theta (Any): all theta particles as PyTree with leading dim ``n_particles``
            grad_log_prob_theta (Any): gradients of all :math:`\\Theta` particles w.r.t. target density,
                as PyTree with leading dim ``n_particles``

        Returns:
            transform PyTree (same structure as ``single_theta``) for the particle ``single_theta``
        """
        # compute terms in sum
        weighted_gradient_ascent = tree_map(
            lambda leaf_theta_grad:
                expand_by(kxx, leaf_theta_grad.ndim - 1) * leaf_theta_grad,
            grad_log_prob_theta)

        repulsion = self._eltwise_grad_kernel_theta(z, theta, single_z, single_theta)

        # average and negate (for optimizer)
        return tree_map(
            lambda grad_asc_leaf, repuls_leaf:
                - (grad_asc_leaf + repuls_leaf).mean(axis=0),
            weighted_gradient_ascent,
            repulsion)

    def _parallel_update_theta(self, *args):
        """
        Vectorizes :func:`~dibs.inference.JointDiBS._theta_update`
        for all available particles in batched first and second input dim (``single_z``, ``single_theta``).
        Otherwise, same inputs as :func:`~dibs.inference.JointDiBS._theta_update`.
        """
        return vmap(self._theta_update, (0, 0, 1, None, None, None), 0)(*args)

    def _svgd_step(self, t, opt_state_z, opt_state_theta, key, sf_baseline):
        """
        Performs a single SVGD step in the DiBS framework, updating all :math:`(Z, \\Theta)` particles jointly.

        Args:
            t (int): step
            opt_state_z: optimizer state for latent :math:`Z` particles; contains ``[n_particles, d, k, 2]``
            opt_state_theta: optimizer state for parameter :math:`\\Theta` particles;
                contains PyTree with ``n_particles`` leading dim
            key (ndarray): prng key
            sf_baseline (ndarray): batch of baseline values of shape ``[n_particles, ]``
                in case score function gradient is used

        Returns:
            the updated inputs ``opt_state_z``, ``opt_state_theta``, ``key``, ``sf_baseline``
        """
        z = self.get_params(opt_state_z)  # [n_particles, d, k, 2]
        theta = self.get_params(opt_state_theta)  # PyTree with `n_particles` leading dim
        n_particles = z.shape[0]

        # d/dtheta log p(theta, D | z)
        key, *batch_subk = random.split(key, n_particles + 1)
        dtheta_log_prob = self.eltwise_grad_theta_likelihood(z, theta, t, jnp.array(batch_subk))

        # d/dz log p(theta, D | z)
        key, *batch_subk = random.split(key, n_particles + 1)
        dz_log_likelihood, sf_baseline = self.eltwise_grad_z_likelihood(z, theta, sf_baseline, t, jnp.array(batch_subk))

        # d/dz log p(z) (acyclicity)
        key, *batch_subk = random.split(key, n_particles + 1)
        dz_log_prior = self.eltwise_grad_latent_prior(z, jnp.array(batch_subk), t)

        # d/dz log p(z, theta, D) = d/dz log p(z) + log p(theta, D | z)
        dz_log_prob = dz_log_prior + dz_log_likelihood

        # k((z, theta), (z, theta)) for all particles
        kxx = self._f_kernel_mat(z, theta, z, theta)

        # transformation phi() applied in batch to each particle individually
        phi_z = self._parallel_update_z(z, theta, kxx, z, theta, dz_log_prob)
        phi_theta = self._parallel_update_theta(z, theta, kxx, z, theta, dtheta_log_prob)

        # apply transformation
        # `x += stepsize * phi`; the phi returned is negated for SVGD
        opt_state_z = self.opt_update(t, phi_z, opt_state_z)
        opt_state_theta = self.opt_update(t, phi_theta, opt_state_theta)

        return opt_state_z, opt_state_theta, key, sf_baseline

    # the @jit here is crucial: it compiles the inner loop over SVGD steps
    @functools.partial(jit, static_argnums=(0, 2))
    def _svgd_loop(self, start, n_steps, init):
        return jax.lax.fori_loop(start, start + n_steps, lambda i, args: self._svgd_step(i, *args), init)

    def sample(self, *, key, n_particles, steps, n_dim_particles=None, callback=None, callback_every=None):
        """
        Use SVGD with DiBS to sample ``n_particles`` particles :math:`(G, \\Theta)` from the joint posterior
        :math:`p(G, \\Theta | D)` as defined by the BN model ``self.likelihood_model``

        Arguments:
            key (ndarray): prng key
            n_particles (int): number of particles to sample
            steps (int): number of SVGD steps performed
            n_dim_particles (int): latent dimensionality :math:`k` of particles
                :math:`Z = \\{ U, V \\}` with :math:`U, V \\in \\mathbb{R}^{k \\times d}`. Default is ``n_vars``
            callback: function to be called every ``callback_every`` steps of SVGD.
            callback_every: if ``None``, ``callback`` is only called after particle updates have finished

        Returns:
            tuple of shape (``[n_particles, n_vars, n_vars]``, ``PyTree``) where ``PyTree`` has leading
            dimension ``n_particles``: batch of samples :math:`G, \\Theta \\sim p(G, \\Theta | D)`
        """

        # randomly sample initial particles
        key, subk = random.split(key)
        init_z, init_theta = self._sample_initial_random_particles(key=subk, n_particles=n_particles, n_dim=n_dim_particles)

        # initialize score function baseline (one for each particle)
        n_particles, _, n_dim, _ = init_z.shape
        sf_baseline = jnp.zeros(n_particles)

        if self.latent_prior_std is None:
            self.latent_prior_std = 1.0 / jnp.sqrt(n_dim)

        # maintain updated particles with optimizer state
        opt_init, self.opt_update, get_params = self.opt
        self.get_params = jit(get_params)
        opt_state_z = opt_init(init_z)
        opt_state_theta = opt_init(init_theta)

        # execute particle update steps for all particles in parallel using `vmap` functions;
        # faster if the for-loop is functionally pure and compiled, so only interrupt for callback
        callback_every = callback_every or steps
        for t in (range(0, steps, callback_every) if steps else range(0)):

            # perform sequence of SVGD steps
            opt_state_z, opt_state_theta, key, sf_baseline = self._svgd_loop(
                t, callback_every, (opt_state_z, opt_state_theta, key, sf_baseline))

            # callback
            if callback:
                z = self.get_params(opt_state_z)
                theta = self.get_params(opt_state_theta)
                callback(
                    dibs=self,
                    t=t + callback_every,
                    zs=z,
                    thetas=theta,
                )

        # retrieve transported particles
        z_final = jax.device_get(self.get_params(opt_state_z))
        theta_final = jax.device_get(self.get_params(opt_state_theta))

        # as alpha is large, we can convert the latents Z to their corresponding graphs G
        g_final = self.particle_to_g_lim(z_final)

        return g_final, theta_final

    def get_empirical(self, g, theta):
        """
        Converts batch of binary (adjacency) matrices and parameters into *empirical* particle distribution,
        where mixture weights correspond to counts/occurrences

        Args:
            g (ndarray): batch of graph samples ``[n_particles, d, d]`` with binary values
            theta (Any): PyTree with leading dim ``n_particles``

        Returns:
            :class:`~dibs.metrics.ParticleDistribution`:
            particle distribution of graph and parameter samples and associated log probabilities
        """
        N, _, _ = g.shape

        # since theta is continuous, each particle (G, theta) is always unique
        logp = - jnp.log(N) * jnp.ones(N)

        return ParticleDistribution(logp=logp, g=g, theta=theta)

    def get_mixture(self, g, theta):
        """
        Converts batch of binary (adjacency) matrices and parameters into *mixture* particle distribution,
        where mixture weights correspond to unnormalized target (i.e. posterior) probabilities

        Args:
            g (ndarray): batch of graph samples ``[n_particles, d, d]`` with binary values
            theta (Any): PyTree with leading dim ``n_particles``

        Returns:
            :class:`~dibs.metrics.ParticleDistribution`:
            particle distribution of graph and parameter samples and associated log probabilities
        """
        N, _, _ = g.shape

        # mixture weighted by respective joint probabilities
        eltwise_log_joint_target = vmap(lambda single_g, single_theta:
            self.log_joint_prob(single_g, single_theta, self.x, self.interv_mask, None), (0, 0), 0)
        logp = eltwise_log_joint_target(g, theta)
        logp -= logsumexp(logp)

        return ParticleDistribution(logp=logp, g=g, theta=theta)
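
# Illustrative usage sketch (not part of the library API): joint inference over graphs
# and parameters with JointDiBS, ending in a posterior-weighted mixture distribution.
# The ErdosReniDAGDistribution and LinearGaussian constructors below are assumptions
# based on the examples named in the class docstring above; check `dibs.models` for
# their exact signatures.
def _example_joint_dibs_usage(key, x):
    """Runs joint DiBS-SVGD on observations ``x`` of shape [n_observations, n_vars]."""
    from dibs.models import ErdosReniDAGDistribution, LinearGaussian  # import path as referenced in the docstring

    n_vars = x.shape[1]
    graph_model = ErdosReniDAGDistribution(n_vars)     # prior p(G)                    (assumed signature)
    likelihood_model = LinearGaussian(n_vars=n_vars)   # joint likelihood p(Theta, D | G)  (assumed signature)

    joint_dibs = JointDiBS(x=x, graph_model=graph_model, likelihood_model=likelihood_model)

    # transport (Z, theta) particle tuples jointly and convert the latents Z to hard graphs G
    key, subk = random.split(key)
    gs, thetas = joint_dibs.sample(key=subk, n_particles=20, steps=1000)

    # weight each (G, theta) particle by its unnormalized posterior probability;
    # use `get_empirical` for uniform weights instead
    return joint_dibs.get_mixture(gs, thetas)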