import jax.numpy as jnp
from jax import random
from dibs.models.graph import ErdosReniDAGDistribution, ScaleFreeDAGDistribution, UniformDAGDistributionRejection
from dibs.graph_utils import graph_to_mat
from dibs.models import LinearGaussian, BGe, DenseNonlinearGaussian
from typing import Any, NamedTuple
[docs]class Data(NamedTuple):
""" NamedTuple for structuring simulated synthetic data and their ground
truth generative model
Args:
passed_key (ndarray): ``jax.random`` key passed *into* the function generating this object
n_vars (int): number of variables in model
n_observations (int): number of observations in ``x`` and used to perform inference
n_ho_observations (int): number of held-out observations in ``x_ho``
and elements of ``x_interv`` used for evaluation
g (ndarray): ground truth DAG
theta (Any): ground truth parameters
x (ndarray): i.i.d observations from the model of shape ``[n_observations, n_vars]``
x_ho (ndarray): i.i.d observations from the model of shape ``[n_ho_observations, n_vars]``
x_interv (list): list of (interv dict, i.i.d observations)
"""
passed_key: Any
n_vars: int
n_observations: int
n_ho_observations: int
g: Any
theta: Any
x: Any
x_ho: Any
x_interv: Any
[docs]def make_synthetic_bayes_net(*,
key,
n_vars,
graph_model,
generative_model,
n_observations=100,
n_ho_observations=100,
n_intervention_sets=10,
perc_intervened=0.1,
):
"""
Returns an instance of :class:`~dibs.metrics.Target` for evaluation of a method on
a ground truth synthetic causal Bayesian network
Args:
key (ndarray): rng key
n_vars (int): number of variables
graph_model (Any): graph model object. For example: :class:`~dibs.models.ErdosReniDAGDistribution`
generative_model (Any): BN model object for generating the observations. For example: :class:`~dibs.models.LinearGaussian`
n_observations (int): number of observations generated for posterior inference
n_ho_observations (int): number of held-out observations generated for evaluation
n_intervention_sets (int): number of different interventions considered overall
for generating interventional data
perc_intervened (float): percentage of nodes intervened upon (clipped to 0) in
an intervention.
Returns:
:class:`~dibs.target.Data`:
synthetic ground truth generative DAG and parameters as well observations sampled from the model
"""
# remember random key
passed_key = key.copy()
# generate ground truth observations
key, subk = random.split(key)
g_gt = graph_model.sample_G(subk)
g_gt_mat = jnp.array(graph_to_mat(g_gt))
key, subk = random.split(key)
theta = generative_model.sample_parameters(key=subk, n_vars=n_vars)
key, subk = random.split(key)
x = generative_model.sample_obs(key=subk, n_samples=n_observations, g=g_gt, theta=theta)
key, subk = random.split(key)
x_ho = generative_model.sample_obs(key=subk, n_samples=n_ho_observations, g=g_gt, theta=theta)
# 10 random 0-clamp interventions where `perc_interv` % of nodes are intervened on
# list of (interv dict, x)
x_interv = []
for idx in range(n_intervention_sets):
# random intervention
key, subk = random.split(key)
n_interv = jnp.ceil(n_vars * perc_intervened).astype(jnp.int32)
interv_targets = random.choice(subk, n_vars, shape=(n_interv,), replace=False)
interv = {int(k): 0.0 for k in interv_targets}
# observations from p(x | theta, G, interv) [n_samples, n_vars]
key, subk = random.split(key)
x_interv_ = generative_model.sample_obs(key=subk, n_samples=n_observations, g=g_gt, theta=theta, interv=interv)
x_interv.append((interv, x_interv_))
# return and save generated target object
data = Data(
passed_key=passed_key,
n_vars=n_vars,
n_observations=n_observations,
n_ho_observations=n_ho_observations,
g=g_gt_mat,
theta=theta,
x=x,
x_ho=x_ho,
x_interv=x_interv,
)
return data
[docs]def make_graph_model(*, n_vars, graph_prior_str, edges_per_node=2):
"""
Instantiates graph model
Args:
n_vars (int): number of variables in graph
graph_prior_str (str): specifier for random graph model; choices: ``er``, ``sf``
edges_per_node (int): number of edges per node (in expectation when applicable)
Returns:
Object representing graph model. For example :class:`~dibs.models.ErdosReniDAGDistribution` or :class:`~dibs.models.ScaleFreeDAGDistribution`
"""
if graph_prior_str == 'er':
graph_model = ErdosReniDAGDistribution(
n_vars=n_vars,
n_edges_per_node=edges_per_node)
elif graph_prior_str == 'sf':
graph_model = ScaleFreeDAGDistribution(
n_vars=n_vars,
n_edges_per_node=edges_per_node)
else:
assert n_vars <= 5, "Naive uniform DAG sampling only possible up to 5 nodes"
graph_model = UniformDAGDistributionRejection(
n_vars=n_vars)
return graph_model
[docs]def make_linear_gaussian_equivalent_model(*, key, n_vars=20, graph_prior_str='sf',
bge_mean_obs=None, bge_alpha_mu=None, bge_alpha_lambd=None,
obs_noise=0.1, mean_edge=0.0, sig_edge=1.0, min_edge=0.5, n_observations=100,
n_ho_observations=100):
"""
Samples a synthetic linear Gaussian BN instance
with Bayesian Gaussian equivalent (BGe) marginal likelihood
as inference model to weight each DAG in an MEC equally
By marginalizing out the parameters, the BGe model does not
allow inferring the parameters :math:`\\Theta`.
Args:
key (ndarray): rng key
n_vars (int): number of variables i
n_observations (int): number of iid observations of variables
n_ho_observations (int): number of iid held-out observations of variables
graph_prior_str (str): graph prior (``er`` or ``sf``)
bge_mean_obs (float): BGe score prior mean parameter of Normal
bge_alpha_mu (float): BGe score prior precision parameter of Normal
bge_alpha_lambd (float): BGe score prior effective sample size (degrees of freedom parameter of Wishart)
obs_noise (float): observation noise
mean_edge (float): edge weight mean
sig_edge (float): edge weight stddev
min_edge (float): min edge weight enforced by constant shift of sampled parameter
Returns:
tuple(:class:`~dibs.models.BGe`, :class:`~dibs.target.Data`):
BGe inference model and observations from a linear Gaussian generative process
"""
# init models
graph_model = make_graph_model(n_vars=n_vars, graph_prior_str=graph_prior_str)
generative_model = LinearGaussian(
n_vars=n_vars,
obs_noise=obs_noise,
mean_edge=mean_edge,
sig_edge=sig_edge,
min_edge=min_edge,
)
likelihood_model = BGe(
n_vars=n_vars,
mean_obs=bge_mean_obs,
alpha_mu=bge_alpha_mu,
alpha_lambd=bge_alpha_lambd,
)
# sample synthetic BN and observations
key, subk = random.split(key)
data = make_synthetic_bayes_net(
key=subk,
n_vars=n_vars,
graph_model=graph_model,
generative_model=generative_model,
n_observations=n_observations,
n_ho_observations=n_ho_observations,
)
return data, graph_model, likelihood_model
[docs]def make_linear_gaussian_model(*, key, n_vars=20, graph_prior_str='sf',
obs_noise=0.1, mean_edge=0.0, sig_edge=1.0, min_edge=0.5, n_observations=100,
n_ho_observations=100):
"""
Samples a synthetic linear Gaussian BN instance
Args:
key (ndarray): rng key
n_vars (int): number of variables
n_observations (int): number of iid observations of variables
n_ho_observations (int): number of iid held-out observations of variables
graph_prior_str (str): graph prior (`er` or `sf`)
obs_noise (float): observation noise
mean_edge (float): edge weight mean
sig_edge (float): edge weight stddev
min_edge (float): min edge weight enforced by constant shift of sampled parameter
Returns:
tuple(:class:`~dibs.models.LinearGaussian`, :class:`~dibs.target.Data`):
linear Gaussian inference model and observations from a linear Gaussian generative process
"""
# init models
graph_model = make_graph_model(n_vars=n_vars, graph_prior_str=graph_prior_str)
generative_model = LinearGaussian(
n_vars=n_vars,
obs_noise=obs_noise,
mean_edge=mean_edge,
sig_edge=sig_edge,
min_edge=min_edge,
)
likelihood_model = LinearGaussian(
n_vars=n_vars,
obs_noise=obs_noise,
mean_edge=mean_edge,
sig_edge=sig_edge,
min_edge=min_edge,
)
# sample synthetic BN and observations
key, subk = random.split(key)
data = make_synthetic_bayes_net(
key=subk,
n_vars=n_vars,
graph_model=graph_model,
generative_model=generative_model,
n_observations=n_observations,
n_ho_observations=n_ho_observations,
)
return data, graph_model, likelihood_model
[docs]def make_nonlinear_gaussian_model(*, key, n_vars=20, graph_prior_str='sf',
obs_noise=0.1, sig_param=1.0, hidden_layers=(5,), n_observations=100,
n_ho_observations=100):
"""
Samples a synthetic nonlinear Gaussian BN instance
where the local conditional distributions are parameterized
by fully-connected neural networks.
Args:
key (ndarray): rng key
n_vars (int): number of variables
n_observations (int): number of iid observations of variables
n_ho_observations (int): number of iid held-out observations of variables
graph_prior_str (str): graph prior (`er` or `sf`)
obs_noise (float): observation noise
sig_param (float): stddev of the BN parameters,
i.e. here the neural net weights and biases
hidden_layers (tuple): list of ints specifying the hidden layer (sizes)
of the neural nets parameterizatin the local condtitionals
Returns:
tuple(:class:`~dibs.models.DenseNonlinearGaussian`, :class:`~dibs.metrics.Target`):
nonlinear Gaussian inference model and observations from a nonlinear Gaussian generative process
"""
# init models
graph_model = make_graph_model(n_vars=n_vars, graph_prior_str=graph_prior_str)
generative_model = DenseNonlinearGaussian(
n_vars=n_vars,
hidden_layers=hidden_layers,
obs_noise=obs_noise,
sig_param=sig_param,
)
likelihood_model = DenseNonlinearGaussian(
n_vars=n_vars,
hidden_layers=hidden_layers,
obs_noise=obs_noise,
sig_param=sig_param,
)
# sample synthetic BN and observations
key, subk = random.split(key)
data = make_synthetic_bayes_net(
key=subk, n_vars=n_vars,
graph_model=graph_model,
generative_model=generative_model,
n_observations=n_observations,
n_ho_observations=n_ho_observations)
return data, graph_model, likelihood_model