import jax.numpy as jnp
from jax import vmap, random, grad
from jax.scipy.special import logsumexp
from jax.nn import sigmoid, log_sigmoid
import jax.lax as lax
from jax.tree_util import tree_map
from dibs.graph_utils import acyclic_constr_nograd
from dibs.utils.func import expand_by, zero_diagonal
class DiBS:
"""
This class implements the backbone for DiBS, i.e. all gradient estimators and sampling
components. Any inference method in the DiBS framework should inherit from this class.
Args:
x (ndarray): matrix of shape ``[n_observations, n_vars]`` of i.i.d. observations of the variables
interv_mask (ndarray): binary matrix of shape ``[n_observations, n_vars]`` indicating
whether a given variable was intervened upon in a given sample (intervention = 1, no intervention = 0)
log_graph_prior (callable):
function implementing prior :math:`\\log p(G)` of soft adjacency matrix of
edge probabilities.
For example: :func:`~dibs.graph.ErdosReniDAGDistribution.unnormalized_log_prob_soft`
or usually bound in e.g. :func:`~dibs.models.LinearGaussian.log_graph_prior`
log_joint_prob (callable):
function implementing joint likelihood :math:`\\log p(\\Theta, D | G)`
of parameters and observations given the discrete graph adjacency matrix
For example: :func:`dibs.models.LinearGaussian.interventional_log_joint_prob`.
When inferring the marginal posterior :math:`p(G | D)` via a closed-form
marginal likelihood :math:`\\log p(D | G)`, the same function signature has to be
satisfied (simply ignoring :math:`\\Theta`)
alpha_linear (float): slope of linear schedule for inverse temperature :math:`\\alpha`
of sigmoid in latent graph model :math:`p(G | Z)`
beta_linear (float): slope of linear schedule for inverse temperature :math:`\\beta`
of constraint penalty in latent prior :math:`p(Z)`
tau (float): constant Gumbel-softmax temperature parameter
n_grad_mc_samples (int): number of Monte Carlo samples in gradient estimator
for likelihood term :math:`p(\\Theta, D | G)`
n_acyclicity_mc_samples (int): number of Monte Carlo samples in gradient estimator
for acyclicity constraint
grad_estimator_z (str): gradient estimator :math:`\\nabla_Z` of expectation over :math:`p(G | Z)`;
choices: ``score`` or ``reparam``
score_function_baseline (float): scale of additive baseline in score function (REINFORCE) estimator;
``score_function_baseline == 0.0`` corresponds to not using a baseline
latent_prior_std (float): standard deviation of Gaussian prior over :math:`Z`; defaults to ``1/sqrt(k)``
"""
def __init__(self, *,
x,
interv_mask,
log_graph_prior,
log_joint_prob,
alpha_linear=0.05,
beta_linear=1.0,
tau=1.0,
n_grad_mc_samples=128,
n_acyclicity_mc_samples=32,
grad_estimator_z='reparam',
score_function_baseline=0.0,
latent_prior_std=None,
verbose=False):
self.x = x
self.interv_mask = interv_mask
self.n_vars = x.shape[-1]
self.log_graph_prior = log_graph_prior
self.log_joint_prob = log_joint_prob
self.alpha = lambda t: (alpha_linear * t)
self.beta = lambda t: (beta_linear * t)
self.tau = tau
self.n_grad_mc_samples = n_grad_mc_samples
self.n_acyclicity_mc_samples = n_acyclicity_mc_samples
self.grad_estimator_z = grad_estimator_z
self.score_function_baseline = score_function_baseline
self.latent_prior_std = latent_prior_std
self.verbose = verbose
"""
Backbone functionality
"""
def particle_to_g_lim(self, z):
"""
Returns :math:`G` corresponding to :math:`\\alpha = \\infty` for particles ``z``
Args:
z (ndarray): latent variables ``[..., d, k, 2]``
Returns:
graph adjacency matrices of shape ``[..., d, d]``
"""
u, v = z[..., 0], z[..., 1]
scores = jnp.einsum('...ik,...jk->...ij', u, v)
g_samples = (scores > 0).astype(jnp.int32)
# mask diagonal since it is explicitly not modeled
return zero_diagonal(g_samples)
def sample_g(self, p, subk, n_samples):
"""
Sample Bernoulli matrix according to matrix of probabilities
Args:
p (ndarray): matrix of probabilities ``[d, d]``
subk (ndarray): rng key
n_samples (int): number of samples
Returns:
an array of matrices sampled according to ``p`` of shape ``[n_samples, d, d]``
"""
n_vars = p.shape[-1]
g_samples = random.bernoulli(
subk, p=p, shape=(n_samples, n_vars, n_vars)).astype(jnp.int32)
# mask diagonal since it is explicitly not modeled
return zero_diagonal(g_samples)
def particle_to_soft_graph(self, z, eps, t):
"""
Gumbel-softmax / concrete distribution using Logistic(0,1) samples ``eps``
Args:
z (ndarray): a single latent tensor :math:`Z` of shape ``[d, k, 2]``
eps (ndarray): random i.i.d. Logistic(0,1) noise of shape ``[d, d]``
t (int): step
Returns:
Gumbel-softmax sample of adjacency matrix of shape ``[d, d]``
"""
scores = jnp.einsum('...ik,...jk->...ij', z[..., 0], z[..., 1])
# soft reparameterization using gumbel-softmax/concrete distribution
# eps ~ Logistic(0,1)
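# i.e. soft_graph_ij = sigmoid(tau * (eps_ij + alpha(t) * <u_i, v_j>));
# as tau -> inf this recovers the hard Bernoulli sample of `particle_to_hard_graph`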
soft_graph = sigmoid(self.tau * (eps + self.alpha(t) * scores))
# mask diagonal since it is explicitly not modeled
return zero_diagonal(soft_graph)
def particle_to_hard_graph(self, z, eps, t):
"""
Bernoulli sample of :math:`G` using probabilities implied by latent ``z``
Args:
z (ndarray): a single latent tensor :math:`Z` of shape ``[d, k, 2]``
eps (ndarray): random i.i.d. Logistic(0,1) noise of shape ``[d, d]``
t (int): step
Returns:
Gumbel-max (hard) sample of adjacency matrix ``[d, d]``
"""
scores = jnp.einsum('...ik,...jk->...ij', z[..., 0], z[..., 1])
# simply take hard limit of sigmoid in gumbel-softmax/concrete distribution
hard_graph = ((eps + self.alpha(t) * scores) > 0.0).astype(jnp.float32)
# mask diagonal since it is explicitly not modeled
return zero_diagonal(hard_graph)
"""
Generative graph model p(G | Z)
"""
def edge_probs(self, z, t):
"""
Edge probabilities encoded by latent representation
Args:
z (ndarray): latent tensors :math:`Z` ``[..., d, k, 2]``
t (int): step
Returns:
edge probabilities of shape ``[..., d, d]``
"""
u, v = z[..., 0], z[..., 1]
scores = jnp.einsum('...ik,...jk->...ij', u, v)
probs = sigmoid(self.alpha(t) * scores)
# mask diagonal since it is explicitly not modeled
return zero_diagonal(probs)
def edge_log_probs(self, z, t):
"""
Edge log probabilities encoded by latent representation
Args:
z (ndarray): latent tensors :math:`Z` ``[..., d, k, 2]``
t (int): step
Returns:
tuple of tensors ``[..., d, d], [..., d, d]`` corresponding to ``log(p)`` and ``log(1-p)``
"""
u, v = z[..., 0], z[..., 1]
scores = jnp.einsum('...ik,...jk->...ij', u, v)
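# log(1 - sigmoid(x)) = log_sigmoid(-x); using log_sigmoid for both terms is numerically stable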
log_probs, log_probs_neg = log_sigmoid(self.alpha(t) * scores), log_sigmoid(self.alpha(t) * -scores)
# mask diagonal since it is explicitly not modeled
# NOTE: this is not technically log(p), but the way `edge_log_probs_` is used, this is correct
return zero_diagonal(log_probs), zero_diagonal(log_probs_neg)
def latent_log_prob(self, single_g, single_z, t):
"""
Log likelihood of generative graph model
Args:
single_g (ndarray): single graph adjacency matrix ``[d, d]``
single_z (ndarray): single latent tensor ``[d, k, 2]``
t (int): step
Returns:
log likelihood :math:`\\log p(G | Z)` of shape ``[1,]``
"""
# [d, d], [d, d]
log_p, log_1_p = self.edge_log_probs(single_z, t)
# [d, d]
log_prob_g_ij = single_g * log_p + (1 - single_g) * log_1_p
# [1,] # diagonal is masked inside `edge_log_probs`
log_prob_g = jnp.sum(log_prob_g_ij)
return log_prob_g
def eltwise_grad_latent_log_prob(self, gs, single_z, t):
"""
Gradient of log likelihood of generative graph model w.r.t. :math:`Z`
i.e. :math:`\\nabla_Z \\log p(G | Z)`
Batched over samples of :math:`G` given a single :math:`Z`.
Args:
gs (ndarray): batch of graph matrices ``[n_graphs, d, d]``
single_z (ndarray): latent variable ``[d, k, 2]``
t (int): step
Returns:
batch of gradients of shape ``[n_graphs, d, k, 2]``
"""
dz_latent_log_prob = grad(self.latent_log_prob, 1)
return vmap(dz_latent_log_prob, (0, None, None), 0)(gs, single_z, t)
"""
Estimators for scores of log p(theta, D | Z)
"""
def eltwise_log_joint_prob(self, gs, single_theta, rng):
"""
Joint likelihood :math:`\\log p(\\Theta, D | G)` batched over samples of :math:`G`
Args:
gs (ndarray): batch of graphs ``[n_graphs, d, d]``
single_theta (Any): single parameter PyTree
rng (ndarray): rng key, e.g. for potential mini-batching of ``x``
Returns:
batch of logprobs of shape ``[n_graphs, ]``
"""
return vmap(self.log_joint_prob, (0, None, None, None, None), 0)(gs, single_theta, self.x, self.interv_mask,
rng)
def log_joint_prob_soft(self, single_z, single_theta, eps, t, subk):
"""
This is the composition of :math:`\\log p(\\Theta, D | G)`
and :math:`G(Z, U)` (Gumbel-softmax graph sample given :math:`Z`)
Args:
single_z (ndarray): single latent tensor ``[d, k, 2]``
single_theta (Any): single parameter PyTree
eps (ndarray): i.i.d Logistic noise of shape ``[d, d]``
t (int): step
subk (ndarray): rng key
Returns:
logprob of shape ``[1, ]``
"""
soft_g_sample = self.particle_to_soft_graph(single_z, eps, t)
return self.log_joint_prob(soft_g_sample, single_theta, self.x, self.interv_mask, subk)
#
# Estimators for score d/dZ log p(theta, D | Z)
# (i.e. w.r.t. the latent embeddings Z for graph G)
#
def eltwise_grad_z_likelihood(self, zs, thetas, baselines, t, subkeys):
"""
Computes batch of estimators for score :math:`\\nabla_Z \\log p(\\Theta, D | Z)`
Selects corresponding estimator used for the term :math:`\\nabla_Z E_{p(G|Z)}[ p(\\Theta, D | G) ]`
and executes it in batch.
Args:
zs (ndarray): batch of latent tensors :math:`Z` ``[n_particles, d, k, 2]``
thetas (Any): batch of parameter PyTrees with ``n_particles`` as leading dim
baselines (ndarray): array of score function baseline values of shape ``[n_particles, ]``
Returns:
tuple batch of (gradient estimates, baselines) of shapes ``[n_particles, d, k, 2], [n_particles, ]``
"""
# select the chosen gradient estimator
if self.grad_estimator_z == 'score':
grad_z_likelihood = self.grad_z_likelihood_score_function
elif self.grad_estimator_z == 'reparam':
grad_z_likelihood = self.grad_z_likelihood_gumbel
else:
raise ValueError(f'Unknown gradient estimator `{self.grad_estimator_z}`')
# vmap
return vmap(grad_z_likelihood, (0, 0, 0, None, 0), (0, 0))(zs, thetas, baselines, t, subkeys)
def grad_z_likelihood_score_function(self, single_z, single_theta, single_sf_baseline, t, subk):
"""
Score function estimator (aka REINFORCE) for the score :math:`\\nabla_Z \\log p(\\Theta, D | Z)`
Uses the same :math:`G \\sim p(G | Z)` samples for expectations in numerator and denominator.
This does not use :math:`\\nabla_G \\log p(\\Theta, D | G)` and is hence applicable when
the gradient w.r.t. the adjacency matrix is not defined (as e.g. for the BGe score).
Args:
single_z (ndarray): single latent tensor ``[d, k, 2]``
single_theta (Any): single parameter PyTree
single_sf_baseline (ndarray): ``[1, ]``
t (int): step
subk (ndarray): rng key
Returns:
tuple of gradient, baseline ``[d, k, 2], [1, ]``
"""
# [d, d]
p = self.edge_probs(single_z, t)
n_vars, n_dim = single_z.shape[0:2]
# [n_grad_mc_samples, d, d]
subk, subk_ = random.split(subk)
g_samples = self.sample_g(p, subk_, self.n_grad_mc_samples)
# same MC samples for numerator and denominator
n_mc_numerator = self.n_grad_mc_samples
n_mc_denominator = self.n_grad_mc_samples
# [n_grad_mc_samples, ]
subk, subk_ = random.split(subk)
logprobs_numerator = self.eltwise_log_joint_prob(g_samples, single_theta, subk_)
logprobs_denominator = logprobs_numerator
# variance_reduction
logprobs_numerator_adjusted = lax.cond(
self.score_function_baseline <= 0.0,
lambda _: logprobs_numerator,
lambda _: logprobs_numerator - single_sf_baseline,
operand=None)
# [d * k * 2, n_grad_mc_samples]
grad_z = self.eltwise_grad_latent_log_prob(g_samples, single_z, t) \
.reshape(self.n_grad_mc_samples, n_vars * n_dim * 2) \
.transpose((1, 0))
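# REINFORCE-style self-normalized estimator:
# grad_Z log E_{p(G|Z)}[p(Theta, D | G)]
#   = E_{p(G|Z)}[p(Theta, D | G) * grad_Z log p(G|Z)] / E_{p(G|Z)}[p(Theta, D | G)]
# both expectations are estimated with the same `n_grad_mc_samples` graph samples,
# and the ratio is computed in log-space below for numerical stability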
# stable computation of exp/log/divide
# [d * k * 2, ] [d * k * 2, ]
log_numerator, sign = logsumexp(a=logprobs_numerator_adjusted, b=grad_z, axis=1, return_sign=True)
# []
log_denominator = logsumexp(logprobs_denominator, axis=0)
# [d * k * 2, ]
stable_sf_grad = sign * jnp.exp(log_numerator - jnp.log(n_mc_numerator) - log_denominator + jnp.log(n_mc_denominator))
# [d, k, 2]
stable_sf_grad_shaped = stable_sf_grad.reshape(n_vars, n_dim, 2)
# update baseline
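# exponential moving average of the mean log joint; `score_function_baseline` weights the new value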
single_sf_baseline = (self.score_function_baseline * logprobs_numerator.mean(0) +
(1 - self.score_function_baseline) * single_sf_baseline)
return stable_sf_grad_shaped, single_sf_baseline
def grad_z_likelihood_gumbel(self, single_z, single_theta, single_sf_baseline, t, subk):
"""
Reparameterization estimator for the score :math:`\\nabla_Z \\log p(\\Theta, D | Z)`
using the Gumbel-softmax / concrete distribution reparameterization trick.
Uses the same :math:`G \\sim p(G | Z)` samples for expectations in numerator and denominator.
This **does** require a well-defined gradient
:math:`\\nabla_G \\log p(\\Theta, D | G)` and is hence only applicable when
the likelihood is differentiable w.r.t. Gumbel-softmax relaxations
of the discrete adjacency matrix.
Any (marginal) likelihood expressible as a function of ``g[:, j]`` and ``theta``,
e.g. using the vector of (possibly soft) parent indicators as a mask, satisfies this.
Examples are: ``dibs.models.LinearGaussian`` and ``dibs.models.DenseNonlinearGaussian``
See also e.g. http://proceedings.mlr.press/v108/zheng20a/zheng20a.pdf
Args:
single_z (ndarray): single latent tensor ``[d, k, 2]``
single_theta (Any): single parameter PyTree
single_sf_baseline (ndarray): ``[1, ]``
t (int): step
subk (ndarray): rng key
Returns:
tuple of gradient, baseline ``[d, k, 2], [1, ]``
"""
n_vars = single_z.shape[0]
# same MC samples for numerator and denominator
n_mc_numerator = self.n_grad_mc_samples
n_mc_denominator = self.n_grad_mc_samples
# sample Logistic(0,1) as randomness in reparameterization
subk, subk_ = random.split(subk)
eps = random.logistic(subk_, shape=(self.n_grad_mc_samples, n_vars, n_vars))
# [n_grad_mc_samples, ]
# since we don't backprop through these evaluations themselves, we are free to use
# either `soft` or `hard` graph samples for the non-gradient evaluation of p(.)
subk, subk_ = random.split(subk)
# [d, k, 2], [d, d], [n_grad_mc_samples, d, d], [1,], [1,] -> [n_grad_mc_samples]
logprobs_numerator = vmap(self.log_joint_prob_soft, (None, None, 0, None, None), 0)(single_z, single_theta, eps, t, subk_)
logprobs_denominator = logprobs_numerator
# [n_grad_mc_samples, d, k, 2]
# d/dx log p(theta, D | G(x, eps)) for a batch of `eps` samples
# use the same minibatch of data as for other log prob evaluation (if using minibatching)
# [d, k, 2], [d, d], [n_grad_mc_samples, d, d], [1,], [1,] -> [n_grad_mc_samples, d, k, 2]
grad_z = vmap(grad(self.log_joint_prob_soft, 0), (None, None, 0, None, None), 0)(single_z, single_theta, eps, t, subk_)
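# self-normalized estimator based on the reparameterized (Gumbel-softmax) samples:
# grad_Z log E_eps[p(Theta, D | G(Z, eps))]
#   ~= sum_m exp(logp_m) * grad_Z logp_m / sum_m exp(logp_m)
# with logp_m = log p(Theta, D | G(Z, eps_m)); the ratio is computed in log-space below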
# stable computation of exp/log/divide
# [d, k, 2], [d, k, 2]
log_numerator, sign = logsumexp(a=logprobs_numerator[:, None, None, None], b=grad_z, axis=0, return_sign=True)
# []
log_denominator = logsumexp(logprobs_denominator, axis=0)
# [d, k, 2]
stable_grad = sign * jnp.exp(log_numerator - jnp.log(n_mc_numerator) - log_denominator + jnp.log(n_mc_denominator))
return stable_grad, single_sf_baseline
#
# Estimators for score d/dtheta log p(theta, D | Z)
# (i.e. w.r.t. the conditional distribution parameters)
#
def eltwise_grad_theta_likelihood(self, zs, thetas, t, subkeys):
"""
Computes batch of estimators for the score :math:`\\nabla_{\\Theta} \\log p(\\Theta, D | Z)`,
i.e. w.r.t the conditional distribution parameters.
Uses the same :math:`G \\sim p(G | Z)` samples for expectations in numerator and denominator.
This does not use :math:`\\nabla_G \\log p(\\Theta, D | G)` and is hence applicable when
the gradient w.r.t. the adjacency matrix is not defined (as e.g. for the BGe score).
Analogous to ``eltwise_grad_z_likelihood`` but with the gradient taken w.r.t. :math:`\\Theta` instead of :math:`Z`
Args:
zs (ndarray): batch of latent tensors Z of shape ``[n_particles, d, k, 2]``
thetas (Any): batch of parameter PyTrees with ``n_particles`` as leading dim
Returns:
batch of gradients in form of ``thetas`` PyTree with ``n_particles`` as leading dim
"""
return vmap(self.grad_theta_likelihood, (0, 0, None, 0), 0)(zs, thetas, t, subkeys)
def grad_theta_likelihood(self, single_z, single_theta, t, subk):
"""
Computes Monte Carlo estimator for the score :math:`\\nabla_{\\Theta} \\log p(\\Theta, D | Z)`
Uses hard samples of :math:`G`, but a soft reparameterization like for :math:`\\nabla_Z` is also possible.
Uses the same :math:`G \\sim p(G | Z)` samples for expectations in numerator and denominator.
Args:
single_z (ndarray): single latent tensor ``[d, k, 2]``
single_theta (Any): single parameter PyTree
t (int): step
subk (ndarray): rng key
Returns:
parameter gradient PyTree
"""
# [d, d]
p = self.edge_probs(single_z, t)
# [n_grad_mc_samples, d, d]
g_samples = self.sample_g(p, subk, self.n_grad_mc_samples)
# same MC samples for numerator and denominator
n_mc_numerator = self.n_grad_mc_samples
n_mc_denominator = self.n_grad_mc_samples
# [n_mc_numerator, ]
subk, subk_ = random.split(subk)
logprobs_numerator = self.eltwise_log_joint_prob(g_samples, single_theta, subk_)
logprobs_denominator = logprobs_numerator
# PyTree shape of `single_theta` with additional leading dimension [n_mc_numerator, ...]
# d/dtheta log p(theta, D | G) for a batch of G samples
# use the same minibatch of data as for other log prob evaluation (if using minibatching)
grad_theta_log_joint_prob = grad(self.log_joint_prob, 1)
grad_theta = vmap(grad_theta_log_joint_prob,
(0, None, None, None, None), 0)(g_samples, single_theta, self.x, self.interv_mask, subk_)
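# self-normalized estimator analogous to the Z case, applied per PyTree leaf:
# grad_Theta log E_{p(G|Z)}[p(Theta, D | G)]
#   = E_{p(G|Z)}[p(Theta, D | G) * grad_Theta log p(Theta, D | G)] / E_{p(G|Z)}[p(Theta, D | G)]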
# stable computation of exp/log/divide and PyTree compatible
# sums over MC graph samples dimension to get MC gradient estimate of theta
# original PyTree shape of `single_theta`
log_numerator = tree_map(
lambda leaf_theta:
logsumexp(a=expand_by(logprobs_numerator, leaf_theta.ndim - 1), b=leaf_theta, axis=0, return_sign=True)[0],
grad_theta)
# original PyTree shape of `single_theta`
sign = tree_map(
lambda leaf_theta:
logsumexp(a=expand_by(logprobs_numerator, leaf_theta.ndim - 1), b=leaf_theta, axis=0, return_sign=True)[1],
grad_theta)
# []
log_denominator = logsumexp(logprobs_denominator, axis=0)
# original PyTree shape of `single_theta`
stable_grad = tree_map(
lambda sign_leaf_theta, log_leaf_theta:
(sign_leaf_theta * jnp.exp(log_leaf_theta - jnp.log(n_mc_numerator) - log_denominator + jnp.log(n_mc_denominator))),
sign, log_numerator)
return stable_grad
"""
Estimators for score d/dZ log p(Z)
"""
def constraint_gumbel(self, single_z, single_eps, t):
"""
Evaluates continuous acyclicity constraint using
Gumbel-softmax instead of Bernoulli samples
Args:
single_z (ndarray): single latent tensor ``[d, k, 2]``
single_eps (ndarray): i.i.d. Logistic noise of shape ``[d, d]`` for Gumbel-softmax
t (int): step
Returns:
constraint value of shape ``[1,]``
"""
n_vars = single_z.shape[0]
G = self.particle_to_soft_graph(single_z, single_eps, t)
h = acyclic_constr_nograd(G, n_vars)
return h
def grad_constraint_gumbel(self, single_z, key, t):
"""
Reparameterization estimator for the gradient :math:`\\nabla_Z E_{p(G|Z)} [ h(G) ]`
where :math:`h` is the acyclicity constraint penalty function.
Since :math:`h` is differentiable w.r.t. :math:`G`, always uses
the Gumbel-softmax / concrete distribution reparameterization trick.
Args:
single_z (ndarray): single latent tensor ``[d, k, 2]``
key (ndarray): rng
t (int): step
Returns:
gradient of shape ``[d, k, 2]``
"""
n_vars = single_z.shape[0]
# [n_mc_samples, d, d]
eps = random.logistic(key, shape=(self.n_acyclicity_mc_samples, n_vars, n_vars))
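# Monte Carlo average of reparameterized gradients:
# grad_Z E_{p(G|Z)}[h(G)] ~= 1/M * sum_m grad_Z h(G(Z, eps_m)), with M = n_acyclicity_mc_samples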
# [n_mc_samples, d, k, 2]
mc_gradient_samples = vmap(grad(self.constraint_gumbel, 0), (None, 0, None), 0)(single_z, eps, t)
# [d, k, 2]
return mc_gradient_samples.mean(0)
def log_graph_prior_particle(self, single_z, t):
"""
Computes :math:`\\log p(G)` component of :math:`\\log p(Z)`,
i.e. not the constraint or Gaussian prior term, but the DAG belief.
The log prior :math:`\\log p(G)` is evaluated with
edge probabilities :math:`G_{\\alpha}(Z)` given :math:`Z`.
Args:
single_z (ndarray): single latent tensor ``[d, k, 2]``
t (int): step
Returns:
log prior graph probability :math:`\\log p(G_{\\alpha}(Z))` of shape ``[1,]``
"""
# [d, d] # masking is done inside `edge_probs`
single_soft_g = self.edge_probs(single_z, t)
# [1, ]
return self.log_graph_prior(soft_g=single_soft_g)
def eltwise_grad_latent_prior(self, zs, subkeys, t):
"""
Computes batch of estimators for the score :math:`\\nabla_Z \\log p(Z)`
with
:math:`\\log p(Z) = - \\beta(t) E_{p(G|Z)} [h(G)] + \\log \\mathcal{N}(Z) + \\log f(Z)`
where :math:`h` is the acyclicity constraint and :math:`f(Z)` is an additional DAG prior factor
computed inside ``dibs.inference.DiBS.log_graph_prior_particle``.
Args:
zs (ndarray): batch of latent tensors ``[n_particles, d, k, 2]``
subkeys (ndarray): batch of rng keys ``[n_particles, ...]``
Returns:
batch of gradients of shape ``[n_particles, d, k, 2]``
"""
# log f(Z) term
# [d, k, 2], [1,] -> [d, k, 2]
grad_log_graph_prior_particle = grad(self.log_graph_prior_particle, 0)
# [n_particles, d, k, 2], [1,] -> [n_particles, d, k, 2]
grad_prior_z = vmap(grad_log_graph_prior_particle, (0, None), 0)(zs, t)
# constraint term
# [n_particles, d, k, 2], [n_particles,], [1,] -> [n_particles, d, k, 2]
eltwise_grad_constraint = vmap(self.grad_constraint_gumbel, (0, 0, None), 0)(zs, subkeys, t)
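# grad_Z log p(Z) combines three terms:
#   - beta(t) * grad_Z E_{p(G|Z)}[h(G)]    (acyclicity penalty)
#   - Z / latent_prior_std^2               (Gaussian prior grad_Z log N(Z; 0, sigma^2 I))
#   + grad_Z log f(Z)                      (graph prior term)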
return - self.beta(t) * eltwise_grad_constraint \
- zs / (self.latent_prior_std ** 2.0) \
+ grad_prior_z
def visualize_callback(self, ipython=True, save_path=None):
"""Returns callback function for visualization of particles during inference updates
Args:
ipython (bool): set to ``True`` when running in a jupyter notebook
save_path (str): path to save plotted images to
Returns:
callback
"""
from dibs.utils.visualize import visualize
from dibs.graph_utils import elwise_acyclic_constr_nograd as constraint
if ipython:
from IPython import display
def callback(**kwargs):
zs = kwargs["zs"]
gs = kwargs["dibs"].particle_to_g_lim(zs)
probs = kwargs["dibs"].edge_probs(zs, kwargs["t"])
if ipython:
display.clear_output(wait=True)
visualize(probs, save_path=save_path, t=kwargs["t"], show=True)
print(
f'iteration {kwargs["t"]:6d}'
f' | alpha {self.alpha(kwargs["t"]):6.1f}'
f' | beta {self.beta(kwargs["t"]):6.1f}'
f' | #cyclic {(constraint(gs, self.n_vars) > 0).sum().item():3d}'
)
return
return callback