Source code for dibs.models.linearGaussian

import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.stats import norm as jax_normal
from jax.scipy.special import gammaln
from dibs.utils.func import _slogdet_jax


class BGe:
    """
    Linear Gaussian BN model corresponding to linear structural equation model (SEM)
    with additive Gaussian noise.

    Uses Normal-Wishart conjugate parameter prior to allow for closed-form marginal
    likelihood :math:`\\log p(D | G)` and thus allows inference of the marginal
    posterior :math:`p(G | D)`.

    For details on the closed-form expression, refer to

    - Geiger et al. (2002): https://projecteuclid.org/download/pdf_1/euclid.aos/1035844981
    - Kuipers et al. (2014): https://projecteuclid.org/download/suppdf_1/euclid.aos/1407420013

    The default arguments imply commonly-used default hyperparameters for mean and
    precision of the Normal-Wishart and assume a diagonal parameter matrix :math:`T`.
    Inspiration for the implementation was drawn from
    https://bitbucket.org/jamescussens/pygobnilp/src/master/pygobnilp/scoring.py

    This implementation uses properties of the determinant to make the computation of
    the marginal likelihood ``jax.jit``-compilable and ``jax.grad``-differentiable
    by remaining well-defined for soft relaxations of the graph.

    Args:
        n_vars (int): number of variables (nodes in the graph)
        mean_obs (ndarray, optional): mean parameter of Normal; defaults to a zero
            vector of length ``n_vars``
        alpha_mu (float, optional): precision parameter of Normal; defaults to ``1.0``
        alpha_lambd (float, optional): degrees of freedom parameter of Wishart;
            defaults to ``n_vars + 2`` and must be larger than ``n_vars + 1``
    """

    def __init__(self, *, n_vars, mean_obs=None, alpha_mu=None, alpha_lambd=None):
        self.n_vars = n_vars

        # `mean_obs or default` raises for any array with more than one element
        # (ambiguous truth value), so test against None explicitly.
        if mean_obs is None:
            self.mean_obs = jnp.zeros(self.n_vars)
        else:
            self.mean_obs = jnp.asarray(mean_obs)

        # Scalars keep the original fallback semantics: a falsy value
        # (None or 0) selects the default hyperparameter.
        self.alpha_mu = alpha_mu or 1.0
        self.alpha_lambd = alpha_lambd or (self.n_vars + 2)

        assert self.alpha_lambd > self.n_vars + 1, \
            "`alpha_lambd` must be larger than `n_vars + 1`"
        self.no_interv_targets = jnp.zeros(self.n_vars).astype(bool)

    def get_theta_shape(self, *, n_vars):
        """Not supported: the BGe score marginalizes out the parameters."""
        raise NotImplementedError("Not available for BGe score; use `LinearGaussian` model instead.")

    def sample_parameters(self, *, key, n_vars, n_particles=0, batch_size=0):
        """Not supported: the BGe score marginalizes out the parameters."""
        raise NotImplementedError("Not available for BGe score; use `LinearGaussian` model instead.")

    def sample_obs(self, *, key, n_samples, g, theta, toporder=None, interv=None):
        """Not supported: the BGe score marginalizes out the parameters."""
        raise NotImplementedError("Not available for BGe score; use `LinearGaussian` model instead.")

    #
    # The following functions need to be functionally pure and jax.jit-compilable
    #

    def _log_marginal_likelihood_single(self, j, n_parents, g, x, interv_targets):
        """
        Computes node-specific score of BGe marginal likelihood. ``jax.jit``-compilable

        Args:
            j (int): node index for score
            n_parents (int): number of parents of node ``j``
            g (ndarray): adjacency matrix of shape ``[d, d]``
            x (ndarray): observations matrix of shape ``[N, d]``
            interv_targets (ndarray): intervention indicator matrix of shape ``[N, d]``

        Returns:
            BGe score for node ``j``
        """
        d = x.shape[-1]

        small_t = (self.alpha_mu * (self.alpha_lambd - d - 1)) / (self.alpha_mu + 1)
        T = small_t * jnp.eye(d)

        # indicator of the rows in which `j` was NOT intervened upon
        not_intervened = 1 - interv_targets[..., j]

        # mask rows of `x` where j is intervened upon to 0.0
        # and compute (remaining) number of observations `N`
        x = x * not_intervened[..., None]
        N = not_intervened.sum()

        # covariance matrix of non-intervened rows
        # (`x_bar` falls back to zeros when there are no such rows)
        x_bar = jnp.where(jnp.isclose(N, 0), jnp.zeros((1, d)), x.sum(axis=0, keepdims=True) / N)
        # re-apply the mask so that intervened rows do not contribute `-x_bar`
        x_center = (x - x_bar) * not_intervened[..., None]
        s_N = x_center.T @ x_center  # [d, d]

        # Kuipers et al. (2014) state `R` wrongly in the paper,
        # using `alpha_lambd` rather than `alpha_mu`;
        # their supplementary contains the correct term
        mean_diff = x_bar - self.mean_obs
        R = T + s_N + ((N * self.alpha_mu) / (N + self.alpha_mu)) * \
            (mean_diff.T @ mean_diff)  # [d, d]

        # (soft) indicator columns selecting the parents of `j`, without and with `j` itself
        parents = g[:, j]
        parents_and_j = (g + jnp.eye(d))[:, j]

        log_gamma_term = (
            0.5 * (jnp.log(self.alpha_mu) - jnp.log(N + self.alpha_mu))
            + gammaln(0.5 * (N + self.alpha_lambd - d + n_parents + 1))
            - gammaln(0.5 * (self.alpha_lambd - d + n_parents + 1))
            - 0.5 * N * jnp.log(jnp.pi)
            # log det(T_JJ)^(..) / det(T_II)^(..) for default T
            + 0.5 * (self.alpha_lambd - d + 2 * n_parents + 1) * jnp.log(small_t)
        )

        log_term_r = (
            # log det(R_II)^(..) / det(R_JJ)^(..)
            0.5 * (N + self.alpha_lambd - d + n_parents) * _slogdet_jax(R, parents)
            - 0.5 * (N + self.alpha_lambd - d + n_parents + 1) * _slogdet_jax(R, parents_and_j)
        )

        # return neutral sum element (0) if no observations (N=0)
        return jnp.where(jnp.isclose(N, 0), 0.0, log_gamma_term + log_term_r)

    def log_marginal_likelihood(self, *, g, x, interv_targets):
        """Computes BGe marginal likelihood :math:`\\log p(D | G)` in closed-form;
        ``jax.jit``-compatible

        Args:
            g (ndarray): adjacency matrix of shape ``[d, d]``
            x (ndarray): observations of shape ``[N, d]``
            interv_targets (ndarray): boolean mask of shape ``[N, d]`` of whether or not
                a node was intervened upon in a given sample. Intervened nodes are
                ignored in likelihood computation

        Returns:
            BGe Score
        """
        # indices
        _, d = x.shape
        nodes_idx = jnp.arange(d)

        # number of parents for each node (column sums of the adjacency matrix)
        n_parents_all = g.sum(axis=0)

        # per-node scores [d,]: vectorize over (node index, parent count) only;
        # graph, data and intervention mask are shared across nodes
        score_per_node = vmap(self._log_marginal_likelihood_single, (0, 0, None, None, None), 0)
        scores = score_per_node(nodes_idx, n_parents_all, g, x, interv_targets)

        return scores.sum(0)

    #
    # Distributions used by DiBS for inference: prior and marginal likelihood
    #

    def interventional_log_marginal_prob(self, g, _, x, interv_targets, rng):
        """Computes interventional marginal likelihood :math:`\\log p(D | G)` in closed-form;
        ``jax.jit``-compatible

        To unify the function signatures for the marginal and joint inference classes
        :class:`~dibs.inference.MarginalDiBS` and :class:`~dibs.inference.JointDiBS`,
        this marginal likelihood is defined with dummy ``theta`` inputs as ``_``,
        i.e., like a joint likelihood

        Arguments:
            g (ndarray): graph adjacency matrix of shape ``[n_vars, n_vars]``.
                Entries must be binary and of type ``jnp.int32``
            _: unused placeholder for ``theta``
            x (ndarray): observational data of shape ``[n_observations, n_vars]``
            interv_targets (ndarray): indicator mask of interventions of shape
                ``[n_observations, n_vars]``
            rng (ndarray): rng; skeleton for minibatching (TBD); currently unused

        Returns:
            BGe score of shape ``[1,]``
        """
        return self.log_marginal_likelihood(g=g, x=x, interv_targets=interv_targets)
class LinearGaussian:
    """
    Linear Gaussian BN model corresponding to linear structural equation model (SEM)
    with additive Gaussian noise.

    Each variable distributed as Gaussian with mean being the linear combination of
    its parents weighted by a Gaussian parameter vector (i.e., with Gaussian-valued edges).
    The noise variance at each node is equal by default, which implies the causal
    structure is identifiable.

    Args:
        n_vars (int): number of variables (nodes in the graph)
        obs_noise (float, optional): variance of additive observation noise at nodes
        mean_edge (float, optional): mean of Gaussian edge weight
        sig_edge (float, optional): std dev of Gaussian edge weight
        min_edge (float, optional): minimum linear effect of parent on child
    """

    def __init__(self, *, n_vars, obs_noise=0.1, mean_edge=0.0, sig_edge=1.0, min_edge=0.5):
        self.n_vars = n_vars
        self.obs_noise = obs_noise
        self.mean_edge = mean_edge
        self.sig_edge = sig_edge
        self.min_edge = min_edge

        self.no_interv_targets = jnp.zeros(self.n_vars).astype(bool)

    def get_theta_shape(self, *, n_vars):
        """Returns tree shape of the parameters of the linear model

        Args:
            n_vars (int): number of variables in model

        Returns:
            PyTree of parameter shape
        """
        return jnp.array((n_vars, n_vars))

    def sample_parameters(self, *, key, n_vars, n_particles=0, batch_size=0):
        """Samples batch of random parameters given dimensions of graph
        from :math:`p(\\Theta | G)`

        Args:
            key (ndarray): rng
            n_vars (int): number of variables in BN
            n_particles (int): number of parameter particles sampled
            batch_size (int): number of batches of particles being sampled

        Returns:
            Parameters ``theta`` of shape ``[batch_size, n_particles, n_vars, n_vars]``,
            dropping dimensions equal to 0
        """
        # `get_theta_shape` returns a JAX array; convert its entries to Python ints
        # so the shape tuple handed to `random.normal` holds plain integers
        theta_dims = tuple(int(dim) for dim in self.get_theta_shape(n_vars=n_vars))
        full_shape = (batch_size, n_particles, *theta_dims)
        shape = tuple(dim for dim in full_shape if dim != 0)

        theta = self.mean_edge + self.sig_edge * random.normal(key, shape=shape)

        # shift every weight away from zero by `min_edge`, keeping its sign
        theta += jnp.sign(theta) * self.min_edge
        return theta

    def sample_obs(self, *, key, n_samples, g, theta, toporder=None, interv=None):
        """Samples ``n_samples`` observations given graph ``g`` and parameters ``theta``

        Args:
            key (ndarray): rng
            n_samples (int): number of samples
            g (igraph.Graph): graph
            theta (Any): parameters
            toporder (list, optional): node indices in the order in which they are
                sampled; defaults to ``g.topological_sorting()``
            interv (dict, optional): intervention specification of the form
                ``{intervened node : clamp value}``

        Returns:
            observation matrix of shape ``[n_samples, n_vars]``
        """
        if interv is None:
            interv = {}
        if toporder is None:
            toporder = g.topological_sorting()

        n_nodes = len(g.vs)
        x = jnp.zeros((n_samples, n_nodes))

        # draw the additive noise for all nodes and samples up front
        key, subk = random.split(key)
        z = jnp.sqrt(self.obs_noise) * random.normal(subk, shape=(n_samples, n_nodes))

        # ancestral sampling
        for j in toporder:

            # intervention: clamp the node, ignoring its parents and noise
            if j in interv:
                x = x.at[:, j].set(interv[j])
                continue

            # regular ancestral sampling
            parent_edges = g.incident(j, mode='in')
            parents = [g.es[e].source for e in parent_edges]

            if parents:
                parent_idx = jnp.array(parents)
                mean = x[:, parent_idx] @ theta[parent_idx, j]
                x = x.at[:, j].set(mean + z[:, j])
            else:
                # root node: pure noise
                x = x.at[:, j].set(z[:, j])

        return x

    #
    # The following functions need to be functionally pure and @jit-able
    #

    def log_prob_parameters(self, *, theta, g):
        """Computes parameter prior :math:`\\log p(\\Theta | G)`

        In this model, the parameter prior is Gaussian.

        Arguments:
            theta (ndarray): parameter matrix of shape ``[n_vars, n_vars]``
            g (ndarray): graph adjacency matrix of shape ``[n_vars, n_vars]``

        Returns:
            log prob
        """
        # weighting by `g` counts only the entries of `theta` that are edges of the graph
        edge_log_probs = jax_normal.logpdf(x=theta, loc=self.mean_edge, scale=self.sig_edge)
        return jnp.sum(g * edge_log_probs)

    def log_likelihood(self, *, x, theta, g, interv_targets):
        """Computes likelihood :math:`p(D | G, \\Theta)`.

        In this model, the noise per observation and node is additive and Gaussian.

        Arguments:
            x (ndarray): observations of shape ``[n_observations, n_vars]``
            theta (ndarray): parameters of shape ``[n_vars, n_vars]``
            g (ndarray): graph adjacency matrix of shape ``[n_vars, n_vars]``
            interv_targets (ndarray): binary intervention indicator vector of shape
                ``[n_observations, n_vars]``

        Returns:
            log prob
        """
        assert x.shape == interv_targets.shape

        # per-entry Gaussian log density around the linear prediction from the parents
        # [n_observations, n_vars]
        log_probs = jax_normal.logpdf(x=x, loc=x @ (g * theta), scale=jnp.sqrt(self.obs_noise))

        # sum scores for all nodes and data; intervened entries contribute 0
        return jnp.sum(jnp.where(interv_targets, 0.0, log_probs))

    #
    # Distributions used by DiBS for inference: prior and joint likelihood
    #

    def interventional_log_joint_prob(self, g, theta, x, interv_targets, rng):
        """Computes interventional joint likelihood :math:`\\log p(\\Theta, D | G)`

        Arguments:
            g (ndarray): graph adjacency matrix of shape ``[n_vars, n_vars]``
            theta (ndarray): parameter matrix of shape ``[n_vars, n_vars]``
            x (ndarray): observational data of shape ``[n_observations, n_vars]``
            interv_targets (ndarray): indicator mask of interventions of shape
                ``[n_observations, n_vars]``
            rng (ndarray): rng; skeleton for minibatching (TBD); currently unused

        Returns:
            log prob of shape ``[1,]``
        """
        log_prob_theta = self.log_prob_parameters(g=g, theta=theta)
        log_likelihood = self.log_likelihood(g=g, theta=theta, x=x, interv_targets=interv_targets)
        return log_prob_theta + log_likelihood