dibs.inference package
DiBS
- class dibs.inference.DiBS(*, x, interv_mask, log_graph_prior, log_joint_prob, alpha_linear=0.05, beta_linear=1.0, tau=1.0, n_grad_mc_samples=128, n_acyclicity_mc_samples=32, grad_estimator_z='reparam', score_function_baseline=0.0, latent_prior_std=None, verbose=False)[source]
This class implements the backbone for DiBS, i.e. all gradient estimators and sampling components. Any inference method in the DiBS framework should inherit from this class.
- Parameters
x (ndarray) – matrix of shape `[n_observations, n_vars]` of i.i.d. observations of the variables
interv_mask (ndarray) – binary matrix of shape `[n_observations, n_vars]` indicating whether a given variable was intervened upon in a given sample (intervention = 1, no intervention = 0)
log_graph_prior (callable) – function implementing the prior \(\log p(G)\) of the soft adjacency matrix of edge probabilities, for example `unnormalized_log_prob_soft()`, usually bound as in e.g. `log_graph_prior()`
log_joint_prob (callable) – function implementing the joint likelihood \(\log p(\Theta, D | G)\) of parameters and observations given the discrete graph adjacency matrix, for example `dibs.models.LinearGaussian.interventional_log_joint_prob()`. When inferring the marginal posterior \(p(G | D)\) via a closed-form marginal likelihood \(\log p(D | G)\), the same function signature has to be satisfied (simply ignoring \(\Theta\))
alpha_linear (float) – slope of linear schedule for inverse temperature \(\alpha\) of sigmoid in latent graph model \(p(G | Z)\)
beta_linear (float) – slope of linear schedule for inverse temperature \(\beta\) of constraint penalty in latent prior \(p(Z)\)
tau (float) – constant Gumbel-softmax temperature parameter
n_grad_mc_samples (int) – number of Monte Carlo samples in gradient estimator for likelihood term \(p(\Theta, D | G)\)
n_acyclicity_mc_samples (int) – number of Monte Carlo samples in gradient estimator for acyclicity constraint
grad_estimator_z (str) – gradient estimator \(\nabla_Z\) of expectation over \(p(G | Z)\); choices: `score` or `reparam`
score_function_baseline (float) – scale of additive baseline in score function (REINFORCE) estimator; `score_function_baseline == 0.0` corresponds to not using a baseline
latent_prior_std (float) – standard deviation of Gaussian prior over \(Z\); defaults to `1/sqrt(k)`
- constraint_gumbel(single_z, single_eps, t)[source]
Evaluates continuous acyclicity constraint using Gumbel-softmax instead of Bernoulli samples
- Parameters
single_z (ndarray) – single latent tensor `[d, k, 2]`
single_eps (ndarray) – i.i.d. Logistic noise of shape `[d, d]` for Gumbel-softmax
t (int) – step
- Returns
constraint value of shape `[1,]`
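For reference, the acyclicity function \(h\) being relaxed here is the polynomial constraint of Yu et al. (2019) used in the DiBS paper, \(h(G) = \mathrm{tr}[(I + G/d)^d] - d\). A minimal JAX sketch (hypothetical helper name):

```python
import jax.numpy as jnp

def acyclicity_constraint(g):
    """h(G) = tr[(I + G/d)^d] - d  (Yu et al., 2019).

    Zero iff G is a DAG; also well-defined for soft adjacency matrices
    with entries in [0, 1], which is what makes the Gumbel-softmax
    relaxation above differentiable end-to-end.
    """
    d = g.shape[-1]
    m = jnp.eye(d) + g / d
    return jnp.trace(jnp.linalg.matrix_power(m, d)) - d
```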
- edge_log_probs(z, t)[source]
Edge log probabilities encoded by latent representation
- Parameters
z (ndarray) – latent tensors \(Z\) `[..., d, k, 2]`
t (int) – step
- Returns
tuple of tensors `[..., d, d], [..., d, d]` corresponding to `log(p)` and `log(1-p)`
- edge_probs(z, t)[source]
Edge probabilities encoded by latent representation
- Parameters
z (ndarray) – latent tensors \(Z\) `[..., d, k, 2]`
t (int) – step
- Returns
edge probabilities of shape `[..., d, d]`
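To make the latent model concrete: in DiBS, \(Z = \{U, V\}\) and the probability of edge \(i \to j\) is \(\sigma(\alpha \, u_i^\top v_j)\). A minimal sketch, assuming `z[..., 0]` and `z[..., 1]` hold the rows \(u_i\) and \(v_j\) (as the `[d, k, 2]` shape suggests) and with the scheduled \(\alpha(t)\) passed in directly as `alpha`:

```python
import jax.numpy as jnp
from jax.nn import sigmoid

def edge_probs_sketch(z, alpha):
    """Edge probabilities sigma(alpha * u_i^T v_j) implied by z [..., d, k, 2].

    Minimal sketch of the bilinear latent graph model from the DiBS paper;
    the diagonal is zeroed since self-loops are excluded.
    """
    u, v = z[..., 0], z[..., 1]                       # [..., d, k] each
    scores = jnp.einsum('...ik,...jk->...ij', u, v)   # inner products u_i^T v_j
    probs = sigmoid(alpha * scores)                   # [..., d, d]
    return probs * (1.0 - jnp.eye(probs.shape[-1]))
```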
- eltwise_grad_latent_log_prob(gs, single_z, t)[source]
Gradient of log likelihood of generative graph model w.r.t. \(Z\), i.e. \(\nabla_Z \log p(G | Z)\), batched over samples of \(G\) given a single \(Z\).
- Parameters
gs (ndarray) – batch of graph matrices `[n_graphs, d, d]`
single_z (ndarray) – latent variable `[d, k, 2]`
t (int) – step
- Returns
batch of gradients of shape `[n_graphs, d, k, 2]`
- eltwise_grad_latent_prior(zs, subkeys, t)[source]
Computes batch of estimators for the score \(\nabla_Z \log p(Z)\) with
\(\log p(Z) = - \beta(t) E_{p(G|Z)} [h(G)] + \log \mathcal{N}(Z) + \log f(Z)\)
where \(h\) is the acyclicity constraint and \(f(Z)\) is an additional DAG prior factor computed inside `dibs.inference.DiBS.log_graph_prior_particle`.
- Parameters
zs (ndarray) – batch of latent tensors `[n_particles, d, k, 2]`
subkeys (ndarray) – batch of rng keys `[n_particles, ...]`
- Returns
batch of gradients of shape `[n_particles, d, k, 2]`
- eltwise_grad_theta_likelihood(zs, thetas, t, subkeys)[source]
Computes batch of estimators for the score \(\nabla_{\Theta} \log p(\Theta, D | Z)\), i.e. w.r.t the conditional distribution parameters. Uses the same \(G \sim p(G | Z)\) samples for expectations in numerator and denominator.
This does not use \(\nabla_G \log p(\Theta, D | G)\) and is hence applicable when the gradient w.r.t. the adjacency matrix is not defined (as e.g. for the BGe score). Analogous to `eltwise_grad_z_likelihood`, but with the gradient w.r.t. \(\Theta\) instead of \(Z\).
- Parameters
zs (ndarray) – batch of latent tensors \(Z\) of shape `[n_particles, d, k, 2]`
thetas (Any) – batch of parameter PyTrees with `n_particles` as leading dim
- Returns
batch of gradients in the form of a `thetas` PyTree with `n_particles` as leading dim
- eltwise_grad_z_likelihood(zs, thetas, baselines, t, subkeys)[source]
Computes batch of estimators for the score \(\nabla_Z \log p(\Theta, D | Z)\). Selects the estimator used for the term \(\nabla_Z E_{p(G|Z)}[ p(\Theta, D | G) ]\) and executes it in batch.
- Parameters
zs (ndarray) – batch of latent tensors \(Z\) `[n_particles, d, k, 2]`
thetas (Any) – batch of parameter PyTrees with `n_particles` as leading dim
baselines (ndarray) – array of score function baseline values of shape `[n_particles, ]`
- Returns
tuple batch of (gradient estimates, baselines) of shapes `[n_particles, d, k, 2], [n_particles, ]`
- eltwise_log_joint_prob(gs, single_theta, rng)[source]
Joint likelihood \(\log p(\Theta, D | G)\) batched over samples of \(G\)
- Parameters
gs (ndarray) – batch of graphs `[n_graphs, d, d]`
single_theta (Any) – single parameter PyTree
rng (ndarray) – rng key, potentially used for mini-batching `x`
- Returns
batch of logprobs of shape `[n_graphs, ]`
- grad_constraint_gumbel(single_z, key, t)[source]
Reparameterization estimator for the gradient \(\nabla_Z E_{p(G|Z)} [ h(G) ]\) where \(h\) is the acyclicity constraint penalty function.
Since \(h\) is differentiable w.r.t. \(G\), always uses the Gumbel-softmax / concrete distribution reparameterization trick.
- Parameters
single_z (ndarray) – single latent tensor `[d, k, 2]`
key (ndarray) – rng
t (int) – step
- Returns
gradient of shape `[d, k, 2]`
- grad_theta_likelihood(single_z, single_theta, t, subk)[source]
Computes Monte Carlo estimator for the score \(\nabla_{\Theta} \log p(\Theta, D | Z)\)
Uses hard samples of \(G\), but a soft reparameterization like for \(\nabla_Z\) is also possible. Uses the same \(G \sim p(G | Z)\) samples for expectations in numerator and denominator.
- Parameters
single_z (ndarray) – single latent tensor `[d, k, 2]`
single_theta (Any) – single parameter PyTree
t (int) – step
subk (ndarray) – rng key
- Returns
parameter gradient PyTree
- grad_z_likelihood_gumbel(single_z, single_theta, single_sf_baseline, t, subk)[source]
Reparameterization estimator for the score \(\nabla_Z \log p(\Theta, D | Z)\) using the Gumbel-softmax / concrete distribution reparameterization trick. Uses the same \(G \sim p(G | Z)\) samples for expectations in numerator and denominator.
This requires a well-defined gradient \(\nabla_G \log p(\Theta, D | G)\) for Gumbel-relaxations of the discrete adjacency matrix and is hence not applicable when the gradient w.r.t. the adjacency matrix is undefined. Any (marginal) likelihood expressible as a function of `g[:, j]` and `theta`, e.g. using the vector of (possibly soft) parent indicators as a mask, satisfies this. Examples are `dibs.models.LinearGaussian` and `dibs.models.DenseNonlinearGaussian`. See also e.g. http://proceedings.mlr.press/v108/zheng20a/zheng20a.pdf
- Parameters
single_z (ndarray) – single latent tensor `[d, k, 2]`
single_theta (Any) – single parameter PyTree
single_sf_baseline (ndarray) – score function baseline of shape `[1, ]`
t (int) – step
subk (ndarray) – rng key
- Returns
tuple of (gradient, baseline) of shapes `[d, k, 2], [1, ]`
- grad_z_likelihood_score_function(single_z, single_theta, single_sf_baseline, t, subk)[source]
Score function estimator (aka REINFORCE) for the score \(\nabla_Z \log p(\Theta, D | Z)\). Uses the same \(G \sim p(G | Z)\) samples for expectations in numerator and denominator.
This does not use \(\nabla_G \log p(\Theta, D | G)\) and is hence applicable when the gradient w.r.t. the adjacency matrix is not defined (as e.g. for the BGe score).
- Parameters
single_z (ndarray) – single latent tensor `[d, k, 2]`
single_theta (Any) – single parameter PyTree
single_sf_baseline (ndarray) – score function baseline of shape `[1, ]`
t (int) – step
subk (ndarray) – rng key
- Returns
tuple of (gradient, baseline) of shapes `[d, k, 2], [1, ]`
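Both estimators above target the gradient of a log expectation, \(\nabla_Z \log E_{p(G|Z)}[f(G)] = E[f(G)\nabla_Z \log p(G|Z)] \,/\, E[f(G)]\), which is why the same Monte Carlo samples appear in numerator and denominator. A minimal log-space sketch of this ratio (hypothetical helper; per-sample quantities assumed precomputed):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def grad_log_expectation_sketch(log_f, grad_log_q):
    """Score-function estimate of grad_Z log E_{p(G|Z)}[f(G)].

    log_f:      [n_mc]           log f(G_m) per sampled graph
    grad_log_q: [n_mc, d, k, 2]  grad_Z log p(G_m | Z) per sample
    The same samples form numerator and denominator; the ratio is
    taken in log-space via logsumexp for numerical stability.
    """
    log_denom = logsumexp(log_f)  # log sum_m f(G_m)
    # signed log of sum_m f(G_m) * grad_Z log p(G_m | Z)
    log_num, sign = logsumexp(
        log_f[:, None, None, None], b=grad_log_q, axis=0, return_sign=True
    )
    return sign * jnp.exp(log_num - log_denom)  # the 1/n_mc factors cancel
```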
- latent_log_prob(single_g, single_z, t)[source]
Log likelihood of generative graph model
- Parameters
single_g (ndarray) – single graph adjacency matrix `[d, d]`
single_z (ndarray) – single latent tensor `[d, k, 2]`
t (int) – step
- Returns
log likelihood \(\log p(G | Z)\) of shape `[1,]`
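Concretely, \(p(G | Z)\) is a product of independent Bernoullis over the off-diagonal entries. A minimal sketch reusing the outputs of `edge_log_probs` above:

```python
import jax.numpy as jnp

def latent_log_prob_sketch(single_g, log_p, log_1mp):
    """log p(G | Z) = sum_{i != j} g_ij log p_ij + (1 - g_ij) log(1 - p_ij).

    single_g:       [d, d] binary adjacency matrix
    log_p, log_1mp: [d, d] as returned by edge_log_probs(z, t)
    """
    d = single_g.shape[-1]
    mask = 1.0 - jnp.eye(d)  # self-loops carry no probability mass
    return jnp.sum(mask * (single_g * log_p + (1 - single_g) * log_1mp))
```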
- log_graph_prior_particle(single_z, t)[source]
Computes the \(\log p(G)\) component of \(\log p(Z)\), i.e. not the constraint or Gaussian prior term, but the DAG belief.
The log prior \(\log p(G)\) is evaluated with edge probabilities \(G_{\alpha}(Z)\) given \(Z\).
- Parameters
single_z (ndarray) – single latent tensor `[d, k, 2]`
t (int) – step
- Returns
log prior graph probability \(\log p(G_\alpha(Z))\) of shape `[1,]`
- log_joint_prob_soft(single_z, single_theta, eps, t, subk)[source]
This is the composition of \(\log p(\Theta, D | G)\) and `particle_to_soft_graph()` (a Gumbel-softmax graph sample given \(Z\))
- Parameters
single_z (ndarray) – single latent tensor `[d, k, 2]`
single_theta (Any) – single parameter PyTree
eps (ndarray) – i.i.d. Logistic noise of shape `[d, d]`
t (int) – step
subk (ndarray) – rng key
- Returns
logprob of shape `[1, ]`
- particle_to_g_lim(z)[source]
Returns \(G\) corresponding to \(\alpha = \infty\) for particles z
- Parameters
z (ndarray) – latent variables `[..., d, k, 2]`
- Returns
graph adjacency matrices of shape `[..., d, d]`
- particle_to_hard_graph(z, eps, t)[source]
Bernoulli sample of \(G\) using probabilities implied by latent `z`
- Parameters
z (ndarray) – single latent tensor \(Z\) of shape `[d, k, 2]`
eps (ndarray) – random i.i.d. Logistic(0,1) noise of shape `[d, d]`
t (int) – step
- Returns
Gumbel-max (hard) sample of adjacency matrix `[d, d]`
- particle_to_soft_graph(z, eps, t)[source]
Gumbel-softmax / concrete distribution sample of \(G\) using Logistic(0,1) samples `eps`
- Parameters
z (ndarray) – single latent tensor \(Z\) of shape `[d, k, 2]`
eps (ndarray) – random i.i.d. Logistic(0,1) noise of shape `[d, d]`
t (int) – step
- Returns
Gumbel-softmax sample of adjacency matrix `[d, d]`
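The two samplers above differ only in how the Logistic noise is consumed. A minimal sketch of both, using the standard binary-concrete form for the soft case (the library's exact \(\tau\) parameterization may differ; self-loop masking omitted for brevity):

```python
import jax.numpy as jnp
from jax.nn import sigmoid

def sample_graph_sketch(z, eps, alpha, tau):
    """Hard (Gumbel-max) and soft (Gumbel-softmax) graph samples from z.

    z: [d, k, 2] latent particle;  eps: [d, d] i.i.d. Logistic(0,1) noise.
    With logits l_ij = alpha * u_i^T v_j, the identity
    1{l + eps > 0} ~ Bernoulli(sigmoid(l)) gives the hard sample.
    """
    logits = alpha * jnp.einsum('ik,jk->ij', z[..., 0], z[..., 1])
    hard = (logits + eps > 0.0).astype(jnp.int32)  # discrete sample
    soft = sigmoid((logits + eps) / tau)           # concrete relaxation in (0, 1)
    return hard, soft
```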
SVGD
- class dibs.inference.MarginalDiBS(*, x, graph_model, likelihood_model, interv_mask=None, kernel=<class 'dibs.kernel.AdditiveFrobeniusSEKernel'>, kernel_param=None, optimizer='rmsprop', optimizer_param=None, alpha_linear=1.0, beta_linear=1.0, tau=1.0, n_grad_mc_samples=128, n_acyclicity_mc_samples=32, grad_estimator_z='score', score_function_baseline=0.0, latent_prior_std=None, verbose=False)[source]
Bases: `dibs.inference.dibs.DiBS`
This class implements Stein Variational Gradient Descent (SVGD) (Liu and Wang, 2016) for DiBS inference (Lorch et al., 2021) of the marginal DAG posterior \(p(G | D)\). For joint inference of \(p(G, \Theta | D)\), use the analogous class `JointDiBS`.
An SVGD update of tensor \(v\) is defined as
\(\phi(v) \propto \sum_{u} k(v, u) \nabla_u \log p(u) + \nabla_u k(u, v)\)
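To unpack this update rule, here is a minimal, generic SVGD step in JAX for dense particles, a sketch of the rule itself rather than the library's internal transport code; `log_prob` and `kernel` are assumed scalar-valued callables:

```python
import jax.numpy as jnp
from jax import grad, vmap

def svgd_step(vs, log_prob, kernel, lr=1e-2):
    """One SVGD update on particles vs of shape [n, dim]:
    phi(v_j) = mean_i [ k(v_i, v_j) grad log p(v_i) + grad_{v_i} k(v_i, v_j) ].
    """
    scores = vmap(grad(log_prob))(vs)                                           # [n, dim]
    k = vmap(lambda u: vmap(lambda v: kernel(u, v))(vs))(vs)                    # [n, n]
    dk = vmap(lambda u: vmap(lambda v: grad(kernel, argnums=0)(u, v))(vs))(vs)  # [n, n, dim]
    phi = (k[:, :, None] * scores[:, None, :] + dk).mean(axis=0)                # [n, dim]
    return vs + lr * phi

# e.g., particles targeting a standard normal with an RBF kernel:
# kernel = lambda u, v: jnp.exp(-jnp.sum((u - v) ** 2))
# log_prob = lambda v: -0.5 * jnp.sum(v ** 2)
```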
- Parameters
x (ndarray) – observations of shape `[n_observations, n_vars]`
interv_mask (ndarray, optional) – binary matrix of shape `[n_observations, n_vars]` indicating whether a given variable was intervened upon in a given sample (intervention = 1, no intervention = 0)
graph_model – Model defining the prior \(\log p(G)\) underlying the inferred posterior. Object has to implement one method: `unnormalized_log_prob_soft`. Example: `ErdosReniDAGDistribution`
likelihood_model – Model defining the marginal likelihood \(\log p(D | G)\) underlying the inferred posterior. Object has to implement one method: `interventional_log_marginal_prob`. Example: `BGe`
kernel – Class of kernel. Has to implement the method `eval(u, v)`. Example: `AdditiveFrobeniusSEKernel`
kernel_param (dict) – kwargs to instantiate `kernel`
optimizer (str) – optimizer identifier
optimizer_param (dict) – kwargs to instantiate `optimizer`
alpha_linear (float) – slope of linear schedule for inverse temperature \(\alpha\) of sigmoid in latent graph model \(p(G | Z)\)
beta_linear (float) – slope of linear schedule for inverse temperature \(\beta\) of constraint penalty in latent prior \(p(Z)\)
tau (float) – constant Gumbel-softmax temperature parameter
n_grad_mc_samples (int) – number of Monte Carlo samples in gradient estimator for likelihood term \(p(\Theta, D | G)\)
n_acyclicity_mc_samples (int) – number of Monte Carlo samples in gradient estimator for acyclicity constraint
grad_estimator_z (str) – gradient estimator \(\nabla_Z\) of expectation over \(p(G | Z)\); choices: `score` or `reparam`
score_function_baseline (float) – scale of additive baseline in score function (REINFORCE) estimator; `score_function_baseline == 0.0` corresponds to not using a baseline
latent_prior_std (float) – standard deviation of Gaussian prior over \(Z\); defaults to `1/sqrt(k)`
- get_empirical(g)[source]
Converts batch of binary (adjacency) matrices into empirical particle distribution where mixture weights correspond to counts/occurrences
- Parameters
g (ndarray) – batch of graph samples `[n_particles, d, d]` with binary values
- Returns
particle distribution of graph samples and associated log probabilities
- Return type
`ParticleDistribution`
- get_mixture(g)[source]
Converts batch of binary (adjacency) matrices into mixture particle distribution, where mixture weights correspond to unnormalized target (i.e. posterior) probabilities
- Parameters
g (ndarray) – batch of graph samples `[n_particles, d, d]` with binary values
- Returns
particle distribution of graph samples and associated log probabilities
- Return type
`ParticleDistribution`
- sample(*, key, n_particles, steps, n_dim_particles=None, callback=None, callback_every=None)[source]
Use SVGD with DiBS to sample `n_particles` particles \(G\) from the marginal posterior \(p(G | D)\) as defined by the BN model `self.inference_model`
- Parameters
key (ndarray) – prng key
n_particles (int) – number of particles to sample
steps (int) – number of SVGD steps performed
n_dim_particles (int) – latent dimensionality \(k\) of particles \(Z = \{ U, V \}\) with \(U, V \in \mathbb{R}^{k \times d}\). Default is `n_vars`
callback – function to be called every `callback_every` steps of SVGD
callback_every – if `None`, `callback` is only called after particle updates have finished
- Returns
batch of samples \(G \sim p(G | D)\) of shape `[n_particles, n_vars, n_vars]`
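A minimal end-to-end usage sketch with the example models named above (`ErdosReniDAGDistribution`, `BGe`); the constructor arguments are illustrative assumptions, and `x` is an `[n_observations, n_vars]` data matrix assumed given:

```python
import jax.random as random
from dibs.inference import MarginalDiBS
from dibs.models import ErdosReniDAGDistribution, BGe

n_vars = 20
graph_model = ErdosReniDAGDistribution(n_vars)  # prior log p(G); args assumed
likelihood_model = BGe(n_vars=n_vars)           # marginal likelihood log p(D | G); args assumed

dibs = MarginalDiBS(x=x, graph_model=graph_model, likelihood_model=likelihood_model)
key, subk = random.split(random.PRNGKey(0))
gs = dibs.sample(key=subk, n_particles=20, steps=1000)  # [20, n_vars, n_vars]

# convert particles into distributions over graphs
dist_empirical = dibs.get_empirical(gs)  # count-based weights
dist_mixture = dibs.get_mixture(gs)      # posterior-weighted particles
```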
- class dibs.inference.JointDiBS(*, x, graph_model, likelihood_model, interv_mask=None, kernel=<class 'dibs.kernel.JointAdditiveFrobeniusSEKernel'>, kernel_param=None, optimizer='rmsprop', optimizer_param=None, alpha_linear=0.05, beta_linear=1.0, tau=1.0, n_grad_mc_samples=128, n_acyclicity_mc_samples=32, grad_estimator_z='reparam', score_function_baseline=0.0, latent_prior_std=None, verbose=False)[source]
Bases: `dibs.inference.dibs.DiBS`
This class implements Stein Variational Gradient Descent (SVGD) (Liu and Wang, 2016) for DiBS inference (Lorch et al., 2021) of the joint posterior \(p(G, \Theta | D)\). For marginal inference of \(p(G | D)\), use the analogous class `MarginalDiBS`.
An SVGD update of tensor \(v\) is defined as
\(\phi(v) \propto \sum_{u} k(v, u) \nabla_u \log p(u) + \nabla_u k(u, v)\)
- Parameters
x (ndarray) – observations of shape `[n_observations, n_vars]`
interv_mask (ndarray, optional) – binary matrix of shape `[n_observations, n_vars]` indicating whether a given variable was intervened upon in a given sample (intervention = 1, no intervention = 0)
graph_model – Model defining the prior \(\log p(G)\) underlying the inferred posterior. Object has to implement one method: `unnormalized_log_prob_soft`. Example: `ErdosReniDAGDistribution`
likelihood_model – Model defining the joint likelihood \(\log p(\Theta, D | G) = \log p(\Theta | G) + \log p(D | G, \Theta)\) underlying the inferred posterior. Object has to implement one method: `interventional_log_joint_prob`. Example: `LinearGaussian`
kernel – Class of kernel. Has to implement the method `eval(u, v)`. Example: `JointAdditiveFrobeniusSEKernel`
kernel_param (dict) – kwargs to instantiate `kernel`
optimizer (str) – optimizer identifier
optimizer_param (dict) – kwargs to instantiate `optimizer`
alpha_linear (float) – slope of linear schedule for inverse temperature \(\alpha\) of sigmoid in latent graph model \(p(G | Z)\)
beta_linear (float) – slope of linear schedule for inverse temperature \(\beta\) of constraint penalty in latent prior \(p(Z)\)
tau (float) – constant Gumbel-softmax temperature parameter
n_grad_mc_samples (int) – number of Monte Carlo samples in gradient estimator for likelihood term \(p(\Theta, D | G)\)
n_acyclicity_mc_samples (int) – number of Monte Carlo samples in gradient estimator for acyclicity constraint
grad_estimator_z (str) – gradient estimator \(\nabla_Z\) of expectation over \(p(G | Z)\); choices: `score` or `reparam`
score_function_baseline (float) – scale of additive baseline in score function (REINFORCE) estimator; `score_function_baseline == 0.0` corresponds to not using a baseline
latent_prior_std (float) – standard deviation of Gaussian prior over \(Z\); defaults to `1/sqrt(k)`
- get_empirical(g, theta)[source]
Converts batch of binary (adjacency) matrices and parameters into empirical particle distribution where mixture weights correspond to counts/occurrences
- Parameters
g (ndarray) – batch of graph samples `[n_particles, d, d]` with binary values
theta (Any) – parameter PyTree with leading dim `n_particles`
- Returns
particle distribution of graph and parameter samples and associated log probabilities
- Return type
`ParticleDistribution`
- get_mixture(g, theta)[source]
Converts batch of binary (adjacency) matrices and parameters into mixture particle distribution, where mixture weights correspond to unnormalized target (i.e. posterior) probabilities
- Parameters
g (ndarray) – batch of graph samples `[n_particles, d, d]` with binary values
theta (Any) – parameter PyTree with leading dim `n_particles`
- Returns
particle distribution of graph and parameter samples and associated log probabilities
- Return type
`ParticleDistribution`
- sample(*, key, n_particles, steps, n_dim_particles=None, callback=None, callback_every=None)[source]
Use SVGD with DiBS to sample `n_particles` particles \((G, \Theta)\) from the joint posterior \(p(G, \Theta | D)\) as defined by the BN model `self.likelihood_model`
- Parameters
key (ndarray) – prng key
n_particles (int) – number of particles to sample
steps (int) – number of SVGD steps performed
n_dim_particles (int) – latent dimensionality \(k\) of particles \(Z = \{ U, V \}\) with \(U, V \in \mathbb{R}^{k \times d}\). Default is `n_vars`
callback – function to be called every `callback_every` steps of SVGD
callback_every – if `None`, `callback` is only called after particle updates have finished
- Returns
batch of samples \(G, \Theta \sim p(G, \Theta | D)\)
- Return type
tuple of shape (`[n_particles, n_vars, n_vars]`, `PyTree`) where `PyTree` has leading dimension `n_particles`
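Analogously, a minimal usage sketch for joint inference with the `LinearGaussian` likelihood named above; constructor arguments are again illustrative assumptions, and `x` is an `[n_observations, n_vars]` data matrix assumed given:

```python
import jax.random as random
from dibs.inference import JointDiBS
from dibs.models import ErdosReniDAGDistribution, LinearGaussian

n_vars = 20
graph_model = ErdosReniDAGDistribution(n_vars)    # prior log p(G); args assumed
likelihood_model = LinearGaussian(n_vars=n_vars)  # joint log p(Theta, D | G); args assumed

dibs = JointDiBS(x=x, graph_model=graph_model, likelihood_model=likelihood_model)
key, subk = random.split(random.PRNGKey(0))
gs, thetas = dibs.sample(key=subk, n_particles=20, steps=1000)
# gs: [20, n_vars, n_vars] adjacency matrices; thetas: PyTree with leading dim 20

dist = dibs.get_mixture(gs, thetas)  # posterior-weighted particle distribution
```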