dibs.models package
Graph models
- class dibs.models.ErdosReniDAGDistribution(n_vars, n_edges_per_node=2)[source]
Randomly oriented Erdős-Rényi random graph model with i.i.d. edge probability. The pmf is defined as
\(p(G) \propto p^e (1-p)^{\binom{d}{2} - e}\)
where \(e\) denotes the total number of edges in \(G\), \(d\) is the number of variables, and \(p\) is chosen so that n_edges_per_node edges are sampled per node in expectation.
- Parameters
n_vars (int) – number of variables in DAG
n_edges_per_node (int) – number of edges sampled per variable in expectation
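The pmf above can be evaluated directly from an adjacency matrix. The following is a minimal sketch in plain Python, not the library's implementation: the helper names are hypothetical, and the choice of \(p\) (expected total edge count n_edges_per_node * n_vars spread over the \(\binom{d}{2}\) node pairs) is an assumption about how the expectation requirement is met. It requires n_edges_per_node < (n_vars - 1) / 2 so that \(p < 1\).

```python
import math


def er_edge_probability(n_vars, n_edges_per_node=2):
    # Assumption: the expected total number of edges is n_edges_per_node * n_vars,
    # distributed uniformly over the d * (d - 1) / 2 unordered node pairs.
    n_pairs = n_vars * (n_vars - 1) / 2
    return n_edges_per_node * n_vars / n_pairs


def er_unnormalized_log_prob(adj, p):
    # adj: square 0/1 adjacency matrix given as nested lists
    d = len(adj)
    n_edges = sum(sum(row) for row in adj)
    n_pairs = d * (d - 1) / 2
    # log of p^e * (1 - p)^(C(d, 2) - e)
    return n_edges * math.log(p) + (n_pairs - n_edges) * math.log(1 - p)


# Usage: a 10-node DAG with two edges, 0 -> 1 and 1 -> 2
p = er_edge_probability(10, n_edges_per_node=2)
adj = [[0] * 10 for _ in range(10)]
adj[0][1] = 1
adj[1][2] = 1
log_prob = er_unnormalized_log_prob(adj, p)
```

Because the expression depends on the graph only through the edge count, the same function also accepts a matrix of edge probabilities, in which case the sum acts as an expected edge count.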
- sample_G(key, return_mat=False)[source]
Samples DAG
- Parameters
key (ndarray) – rng
return_mat (bool) – if True, returns adjacency matrix of shape [n_vars, n_vars]
- Returns
DAG
- Return type
igraph.Graph / jnp.array
- unnormalized_log_prob(*, g)[source]
Computes \(\log p(G)\) up to the normalization constant
- Parameters
g (igraph.Graph) – graph
- Returns
unnormalized log probability of \(G\)
- unnormalized_log_prob_single(*, g, j)[source]
Computes \(\log p(G_j)\) up to the normalization constant
- Parameters
g (igraph.Graph) – graph
j (int) – node index
- Returns
unnormalized log probability of node family of \(j\)
- unnormalized_log_prob_soft(*, soft_g)[source]
Computes \(\log p(G)\) up to the normalization constant, where \(G\) is the matrix of edge probabilities
- Parameters
soft_g (ndarray) – graph adjacency matrix, where entries may be probabilities and not necessarily 0 or 1
- Returns
unnormalized log probability corresponding to edge probabilities in \(G\)
- class dibs.models.ScaleFreeDAGDistribution(n_vars, verbose=False, n_edges_per_node=2)[source]
Randomly-oriented scale-free random graph with power-law degree distribution. The pmf is defined as
\(p(G) \propto \prod_j (1 + \text{deg}(j))^{-3}\)
where \(\text{deg}(j)\) denotes the in-degree of node \(j\)
- Parameters
n_vars (int) – number of variables in DAG
n_edges_per_node (int) – number of edges sampled per variable
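As a small illustration of the scale-free pmf above, the sketch below computes \(\sum_j -3 \log(1 + \text{deg}(j))\) from an adjacency matrix. It is a plain-Python sketch with hypothetical helper names, and it assumes the convention that entry [i][j] = 1 encodes an edge from node i to node j, so that in-degrees are column sums.

```python
import math


def in_degrees(adj):
    # Assumed convention: adj[i][j] == 1 means an edge i -> j,
    # so the in-degree of node j is the sum of column j.
    d = len(adj)
    return [sum(adj[i][j] for i in range(d)) for j in range(d)]


def scale_free_unnormalized_log_prob(adj):
    # log of prod_j (1 + deg(j))^(-3)
    return sum(-3.0 * math.log(1.0 + deg) for deg in in_degrees(adj))


# Usage: 3-node DAG with edges 0 -> 1, 0 -> 2, 1 -> 2 (in-degrees 0, 1, 2)
adj = [
    [0, 1, 1],
    [0, 0, 1],
    [0, 0, 0],
]
log_prob = scale_free_unnormalized_log_prob(adj)
```

If the entries are edge probabilities rather than 0/1 values, the column sums become expected in-degrees and the same expression stays well-defined.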
- sample_G(key, return_mat=False)[source]
Samples DAG
- Parameters
key (ndarray) – rng
return_mat (bool) – if True, returns adjacency matrix of shape [n_vars, n_vars]
- Returns
DAG
- Return type
igraph.Graph / jnp.array
- unnormalized_log_prob(*, g)[source]
Computes \(\log p(G)\) up to the normalization constant
- Parameters
g (igraph.Graph) – graph
- Returns
unnormalized log probability of \(G\)
- unnormalized_log_prob_single(*, g, j)[source]
Computes \(\log p(G_j)\) up to the normalization constant
- Parameters
g (igraph.Graph) – graph
j (int) – node index
- Returns
unnormalized log probability of node family of \(j\)
- unnormalized_log_prob_soft(*, soft_g)[source]
Computes \(\log p(G)\) up to the normalization constant, where \(G\) is the matrix of edge probabilities
- Parameters
soft_g (ndarray) – graph adjacency matrix, where entries may be probabilities and not necessarily 0 or 1
- Returns
unnormalized log probability corresponding to edge probabilities in \(G\)
Observational models
- class dibs.models.BGe(*, n_vars, mean_obs=None, alpha_mu=None, alpha_lambd=None)[source]
Linear Gaussian BN model corresponding to linear structural equation model (SEM) with additive Gaussian noise. Uses Normal-Wishart conjugate parameter prior to allow for closed-form marginal likelihood \(\log p(D | G)\) and thus allows inference of the marginal posterior \(p(G | D)\)
For details on the closed-form expression, refer to
Geiger et al. (2002): https://projecteuclid.org/download/pdf_1/euclid.aos/1035844981
Kuipers et al. (2014): https://projecteuclid.org/download/suppdf_1/euclid.aos/1407420013
The default arguments imply commonly-used default hyperparameters for mean and precision of the Normal-Wishart and assume a diagonal parameter matrix \(T\). Inspiration for the implementation was drawn from https://bitbucket.org/jamescussens/pygobnilp/src/master/pygobnilp/scoring.py
This implementation uses properties of the determinant to make the computation of the marginal likelihood jax.jit-compilable and jax.grad-differentiable by remaining well-defined for soft relaxations of the graph.
- Parameters
n_vars (int) – number of variables (nodes in the graph)
mean_obs (ndarray, optional) – mean parameter of Normal
alpha_mu (float, optional) – precision parameter of Normal
alpha_lambd (float, optional) – degrees of freedom parameter of Wishart
- interventional_log_marginal_prob(g, _, x, interv_targets, rng)[source]
Computes interventional marginal likelihood \(\log p(D | G)\) in closed form; jax.jit-compatible.
To unify the function signatures for the marginal and joint inference classes MarginalDiBS and JointDiBS, this marginal likelihood is defined with dummy theta inputs as _, i.e., like a joint likelihood.
- Parameters
g (ndarray) – graph adjacency matrix of shape [n_vars, n_vars]. Entries must be binary and of type jnp.int32
_ – unused placeholder for theta
x (ndarray) – observational data of shape [n_observations, n_vars]
interv_targets (ndarray) – indicator mask of interventions of shape [n_observations, n_vars]
rng (ndarray) – rng; skeleton for minibatching (TBD)
- Returns
BGe score of shape [1,]
- log_marginal_likelihood(*, g, x, interv_targets)[source]
Computes BGe marginal likelihood \(\log p(D | G)\) in closed form; jax.jit-compatible.
- Parameters
g (ndarray) – adjacency matrix of shape [d, d]
x (ndarray) – observations of shape [N, d]
interv_targets (ndarray) – boolean mask of shape [N, d] indicating whether a node was intervened upon in a given sample. Intervened nodes are ignored in the likelihood computation
- Returns
BGe score
- class dibs.models.LinearGaussian(*, n_vars, obs_noise=0.1, mean_edge=0.0, sig_edge=1.0, min_edge=0.5)[source]
Linear Gaussian BN model corresponding to linear structural equation model (SEM) with additive Gaussian noise.
Each variable is distributed as a Gaussian whose mean is a linear combination of its parents, weighted by a Gaussian parameter vector (i.e., with Gaussian-valued edges). The noise variance is equal across nodes by default, which implies that the causal structure is identifiable.
- Parameters
n_vars (int) – number of variables (nodes in the graph)
obs_noise (float, optional) – variance of additive observation noise at nodes
mean_edge (float, optional) – mean of Gaussian edge weight
sig_edge (float, optional) – std dev of Gaussian edge weight
min_edge (float, optional) – minimum linear effect of parent on child
- get_theta_shape(*, n_vars)[source]
Returns tree shape of the parameters of the linear model
- Parameters
n_vars (int) – number of variables in model
- Returns
PyTree of parameter shape
- interventional_log_joint_prob(g, theta, x, interv_targets, rng)[source]
Computes interventional joint likelihood \(\log p(\Theta, D | G)\)
- Parameters
g (ndarray) – graph adjacency matrix of shape [n_vars, n_vars]
theta (ndarray) – parameter matrix of shape [n_vars, n_vars]
x (ndarray) – observational data of shape [n_observations, n_vars]
interv_targets (ndarray) – indicator mask of interventions of shape [n_observations, n_vars]
rng (ndarray) – rng; skeleton for minibatching (TBD)
- Returns
log prob of shape [1,]
- log_likelihood(*, x, theta, g, interv_targets)[source]
Computes the log-likelihood \(\log p(D | G, \Theta)\). In this model, the noise per observation and node is additive and Gaussian.
- Parameters
x (ndarray) – observations of shape [n_observations, n_vars]
theta (ndarray) – parameters of shape [n_vars, n_vars]
g (ndarray) – graph adjacency matrix of shape [n_vars, n_vars]
interv_targets (ndarray) – binary intervention indicator mask of shape [n_observations, n_vars]
- Returns
log prob
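To make the role of the intervention mask concrete, here is a minimal plain-Python sketch of a linear Gaussian log-likelihood. It is not the library's code: the function names are hypothetical, it assumes that entry [i][j] of g and theta refers to the edge i -> j, that obs_noise is a variance (as documented for the class), and that intervened entries contribute nothing to the sum.

```python
import math


def gaussian_logpdf(value, mean, var):
    # Log density of N(mean, var) evaluated at value
    return -0.5 * (math.log(2.0 * math.pi * var) + (value - mean) ** 2 / var)


def linear_gaussian_log_likelihood(x, theta, g, interv_targets, obs_noise=0.1):
    # x: [n_observations][n_vars], theta and g: [n_vars][n_vars],
    # interv_targets: [n_observations][n_vars] with truthy entries for intervened nodes
    n_vars = len(g)
    total = 0.0
    for row, mask in zip(x, interv_targets):
        for j in range(n_vars):
            if mask[j]:
                continue  # intervened nodes are skipped (assumed behavior)
            # Mean of node j: weighted sum over its parents i with g[i][j] == 1
            mean_j = sum(row[i] * g[i][j] * theta[i][j] for i in range(n_vars))
            total += gaussian_logpdf(row[j], mean_j, obs_noise)
    return total


# Usage: two nodes with a single edge 0 -> 1 of weight 2.0
g = [[0, 1], [0, 0]]
theta = [[0.0, 2.0], [0.0, 0.0]]
x = [[1.0, 2.0]]
log_lik_obs = linear_gaussian_log_likelihood(x, theta, g, [[0, 0]])
log_lik_int = linear_gaussian_log_likelihood(x, theta, g, [[0, 1]])
```

In the second call node 1 is marked as intervened, so only the term for node 0 remains.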
- log_prob_parameters(*, theta, g)[source]
Computes parameter prior \(\log p(\Theta | G)\). In this model, the parameter prior is Gaussian.
- Parameters
theta (ndarray) – parameter matrix of shape [n_vars, n_vars]
g (ndarray) – graph adjacency matrix of shape [n_vars, n_vars]
- Returns
log prob
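A Gaussian edge-weight prior conditioned on the graph can be sketched as follows. This is an illustrative plain-Python version with a hypothetical function name. It assumes that only weights of edges present in g contribute to \(\log p(\Theta | G)\), and that mean_edge and sig_edge are the mean and standard deviation documented for the class.

```python
import math
from statistics import NormalDist


def linear_gaussian_log_prior(theta, g, mean_edge=0.0, sig_edge=1.0):
    # Sum of Gaussian log densities over the weights of edges present in g.
    # Assumption: weights of absent edges (g[i][j] == 0) do not contribute.
    prior = NormalDist(mu=mean_edge, sigma=sig_edge)
    n_vars = len(g)
    total = 0.0
    for i in range(n_vars):
        for j in range(n_vars):
            if g[i][j]:
                total += math.log(prior.pdf(theta[i][j]))
    return total


# Usage: one edge 0 -> 1 with weight 0.5; the entry theta[1][0] is ignored
g = [[0, 1], [0, 0]]
theta = [[0.0, 0.5], [3.0, 0.0]]
log_prior = linear_gaussian_log_prior(theta, g)
```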
- sample_obs(*, key, n_samples, g, theta, toporder=None, interv=None)[source]
Samples n_samples observations given graph g and parameters theta
- Parameters
key (ndarray) – rng
n_samples (int) – number of samples
g (igraph.Graph) – graph
theta (Any) – parameters
interv (dict) – intervention specification of the form {intervened node : clamp value}
- Returns
observation matrix of shape [n_samples, n_vars]
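Sampling from a linear Gaussian SEM with clamped interventions amounts to ancestral sampling in topological order. The sketch below shows the idea in plain Python. It is not the library's sample_obs: it takes an adjacency matrix instead of an igraph.Graph and a seed instead of a JAX key, and the function name and the seed argument are hypothetical.

```python
import math
import random


def sample_linear_gaussian(n_samples, g, theta, toporder,
                           obs_noise=0.1, interv=None, seed=0):
    # g, theta: [n_vars][n_vars] with entry [i][j] describing edge i -> j (assumed)
    # toporder: node indices ordered so that parents precede children
    # interv: dict {node: clamp value}; clamped nodes ignore their parents
    rng = random.Random(seed)
    interv = interv or {}
    n_vars = len(g)
    noise_std = math.sqrt(obs_noise)  # obs_noise is a variance
    samples = []
    for _ in range(n_samples):
        x = [0.0] * n_vars
        for j in toporder:
            if j in interv:
                x[j] = interv[j]
                continue
            # Parents of j were already sampled because of the topological order
            mean_j = sum(x[i] * g[i][j] * theta[i][j] for i in range(n_vars))
            x[j] = mean_j + rng.gauss(0.0, noise_std)
        samples.append(x)
    return samples


# Usage: chain 0 -> 1 -> 2 with node 1 clamped to 5.0
g = [[0, 1, 0], [0, 0, 1], [0, 0, 0]]
theta = [[0.0, 2.0, 0.0], [0.0, 0.0, 3.0], [0.0, 0.0, 0.0]]
data = sample_linear_gaussian(4, g, theta, toporder=[0, 1, 2], interv={1: 5.0})
```

Clamping node 1 cuts its dependence on node 0, while node 2 still responds to the clamped value through the edge 1 -> 2.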
- sample_parameters(*, key, n_vars, n_particles=0, batch_size=0)[source]
Samples batch of random parameters given dimensions of graph from \(p(\Theta | G)\)
- Parameters
key (ndarray) – rng
n_vars (int) – number of variables in BN
n_particles (int) – number of parameter particles sampled
batch_size (int) – number of batches of particles being sampled
- Returns
Parameters theta of shape [batch_size, n_particles, n_vars, n_vars], dropping dimensions equal to 0
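The rule for dropping leading dimensions can be spelled out with a small helper. This is only a sketch of the documented shape convention, with a hypothetical function name; it does not sample any parameters.

```python
def theta_shape(n_vars, n_particles=0, batch_size=0):
    # Leading dimensions equal to 0 are dropped from the result
    leading = tuple(dim for dim in (batch_size, n_particles) if dim != 0)
    return leading + (n_vars, n_vars)


theta_shape(5)                                # (5, 5)
theta_shape(5, n_particles=10)                # (10, 5, 5)
theta_shape(5, n_particles=10, batch_size=3)  # (3, 10, 5, 5)
```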
- class dibs.models.DenseNonlinearGaussian(*, n_vars, hidden_layers, obs_noise=0.1, sig_param=1.0, activation='relu', bias=True)[source]
Nonlinear Gaussian BN model corresponding to a nonlinear structural equation model (SEM) with additive Gaussian noise.
Each variable is distributed as a Gaussian whose mean is parameterized by a dense neural network (MLP) with weights and biases sampled from a Gaussian prior. The noise variance is equal across nodes by default.
Refer to http://proceedings.mlr.press/v108/zheng20a/zheng20a.pdf
- Parameters
n_vars (int) – number of variables (nodes in the graph)
hidden_layers (tuple) – list of integers specifying the number of layers as well as their widths. For example: [8, 8] would correspond to 2 hidden layers with 8 neurons each
obs_noise (float, optional) – variance of additive observation noise at nodes
sig_param (float, optional) – std dev of Gaussian parameter prior
activation (str, optional) – identifier for activation function. Choices: sigmoid, tanh, relu, leakyrelu
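To illustrate how hidden_layers translates into network layers, the sketch below lists the weight-matrix shapes of a single node's MLP. It assumes that each node's network takes all n_vars variables as input (with non-parents masked out) and outputs one scalar mean; that layout and the function name are assumptions for illustration, not taken from the documentation above.

```python
def mlp_layer_shapes(n_vars, hidden_layers):
    # Widths along the network: input, hidden layers, scalar output (assumed layout)
    widths = [n_vars, *hidden_layers, 1]
    return list(zip(widths[:-1], widths[1:]))


mlp_layer_shapes(5, [8, 8])  # [(5, 8), (8, 8), (8, 1)]
```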
- get_theta_shape(*, n_vars)[source]
Returns tree shape of the parameters of the neural networks
- Parameters
n_vars (int) – number of variables in model
- Returns
PyTree of parameter shape
- interventional_log_joint_prob(g, theta, x, interv_targets, rng)[source]
Computes interventional joint likelihood \(\log p(\Theta, D | G)\)
- Parameters
g (ndarray) – graph adjacency matrix of shape [n_vars, n_vars]
theta (Any) – parameter PyTree
x (ndarray) – observational data of shape [n_observations, n_vars]
interv_targets (ndarray) – indicator mask of interventions of shape [n_observations, n_vars]
rng (ndarray) – rng; skeleton for minibatching (TBD)
- Returns
log prob of shape [1,]
- log_likelihood(*, x, theta, g, interv_targets)[source]
Computes the log-likelihood \(\log p(D | G, \Theta)\). In this model, the noise per observation and node is additive and Gaussian.
- Parameters
x (ndarray) – observations of shape [n_observations, n_vars]
theta (Any) – parameters PyTree
g (ndarray) – graph adjacency matrix of shape [n_vars, n_vars]
interv_targets (ndarray) – binary intervention indicator vector of shape [n_vars, ]
- Returns
log prob
- log_prob_parameters(*, theta, g)[source]
Computes parameter prior \(\log p(\Theta | G)\). In this model, the prior over weights and biases is zero-centered Gaussian.
- Parameters
theta (Any) – parameter PyTree
g (ndarray) – graph adjacency matrix of shape [n_vars, n_vars]
- Returns
log prob
- sample_obs(*, key, n_samples, g, theta, toporder=None, interv=None)[source]
Samples n_samples observations given graph g and parameters theta by doing single forward passes in topological order
- Parameters
key (ndarray) – rng
n_samples (int) – number of samples
g (igraph.Graph) – graph
theta (Any) – parameters
interv (dict) – intervention specification of the form {intervened node : clamp value}
- Returns
observation matrix of shape [n_samples, n_vars]
- sample_parameters(*, key, n_vars, n_particles=0, batch_size=0)[source]
Samples batch of random parameters given dimensions of graph from \(p(\Theta | G)\)
- Parameters
key (ndarray) – rng
n_vars (int) – number of variables in BN
n_particles (int) – number of parameter particles sampled
batch_size (int) – number of batches of particles being sampled
- Returns
Parameter PyTree with leading dimension(s) batch_size and/or n_particles, dropping either dimension when equal to 0