Source code for dibs.models.nonlinearGaussian

import os
import numpy as onp

import jax.numpy as jnp
from jax import vmap
from jax import random
from jax.scipy.stats import norm as jax_normal
from jax.tree_util import tree_map, tree_reduce
from jax.nn.initializers import normal
import jax.example_libraries.stax as stax
from jax.example_libraries.stax import Dense, Sigmoid, LeakyRelu, Relu, Tanh

from dibs.graph_utils import graph_to_mat
from dibs.utils.tree import tree_shapes


def dense_no_bias(out_dim, w_init=None):
    """Layer constructor function for a dense (fully-connected) layer _without_ bias"""

    if w_init is None:
        w_init = normal()

    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim,)
        w = w_init(rng, (input_shape[-1], out_dim))
        return output_shape, (w, )

    def apply_fun(params, inputs, **kwargs):
        w, = params
        return jnp.dot(inputs, w)

    return init_fun, apply_fun
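
# Illustration (not part of the library): how the `(init_fun, apply_fun)` pair returned
# by `dense_no_bias` is used, following the stax layer convention. The layer width,
# input shape, and rng seed below are arbitrary assumptions for demonstration.
#
#   init_fun, apply_fun = dense_no_bias(4, w_init=normal(stddev=1.0))
#   out_shape, params = init_fun(random.PRNGKey(0), (10, 3))   # out_shape == (10, 4)
#   y = apply_fun(params, jnp.ones((10, 3)))                   # y.shape == (10, 4)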


def make_dense_net(*, hidden_layers, sig_weight, sig_bias, bias=True, activation='relu'):
    """
    Generates functions defining a fully-connected NN
    with Gaussian initialized parameters

    Args:
        hidden_layers (tuple): list of ints specifying the widths of the hidden layers
        sig_weight: std dev of weight initialization
        sig_bias: std dev of bias initialization
        bias: whether to include bias terms in the linear layers
        activation: activation function str; choices: `sigmoid`, `tanh`, `relu`, `leakyrelu`

    Returns:
        ``(init_fun, apply_fun)`` pair as returned by ``stax.serial``
    """

    # features: [hidden_layers[0], hidden_layers[1], ..., hidden_layers[-1], 1]
    if activation == 'sigmoid':
        f_activation = Sigmoid
    elif activation == 'tanh':
        f_activation = Tanh
    elif activation == 'relu':
        f_activation = Relu
    elif activation == 'leakyrelu':
        f_activation = LeakyRelu
    else:
        raise KeyError(f'Invalid activation function `{activation}`')

    modules = []
    if bias:
        for dim in hidden_layers:
            modules += [
                Dense(dim, W_init=normal(stddev=sig_weight),
                        b_init=normal(stddev=sig_bias)),
                f_activation
            ]
        modules += [Dense(1, W_init=normal(stddev=sig_weight),
                            b_init=normal(stddev=sig_bias))]
    else:
        for dim in hidden_layers:
            modules += [
                dense_no_bias(dim, w_init=normal(stddev=sig_weight)),
                f_activation
            ]
        modules += [dense_no_bias(1, w_init=normal(stddev=sig_weight))]

    return stax.serial(*modules)
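
# Illustration (not part of the library): building a single MLP with `make_dense_net`
# and running a forward pass; widths, stddevs, and batch size are arbitrary assumptions.
#
#   init_random_params, net_forward = make_dense_net(
#       hidden_layers=(8, 8), sig_weight=1.0, sig_bias=1.0, activation='relu')
#   out_shape, params = init_random_params(random.PRNGKey(0), (-1, 5))  # out_shape == (-1, 1)
#   y = net_forward(params, jnp.ones((20, 5)))                          # y.shape == (20, 1)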


class DenseNonlinearGaussian:
    """
    Nonlinear Gaussian BN model corresponding to a nonlinear structural equation model (SEM)
    with additive Gaussian noise.

    Each variable is distributed as Gaussian with mean parameterized by a dense
    neural network (MLP) whose weights and biases are sampled from a Gaussian prior.
    The noise variance at each node is equal by default.
    Refer to http://proceedings.mlr.press/v108/zheng20a/zheng20a.pdf

    Args:
        n_vars (int): number of variables (nodes in the graph)
        hidden_layers (tuple): list of integers specifying the number of hidden layers
            as well as their widths. For example: ``[8, 8]`` corresponds to 2 hidden
            layers with 8 neurons each
        obs_noise (float, optional): variance of additive observation noise at nodes
        sig_param (float, optional): std dev of Gaussian parameter prior
        activation (str, optional): identifier for activation function.
            Choices: ``sigmoid``, ``tanh``, ``relu``, ``leakyrelu``
        bias (bool, optional): whether the MLPs include bias terms

    """

    def __init__(self, *, n_vars, hidden_layers, obs_noise=0.1, sig_param=1.0,
                 activation='relu', bias=True):
        self.n_vars = n_vars
        self.obs_noise = obs_noise
        self.sig_param = sig_param
        self.hidden_layers = hidden_layers
        self.activation = activation
        self.bias = bias

        self.no_interv_targets = jnp.zeros(self.n_vars).astype(bool)

        # init single neural net function for one variable with jax stax
        self.nn_init_random_params, nn_forward = make_dense_net(
            hidden_layers=self.hidden_layers,
            sig_weight=self.sig_param,
            sig_bias=self.sig_param,
            activation=self.activation,
            bias=self.bias)

        # [?], [N, d] -> [N,]
        self.nn_forward = lambda theta, x: nn_forward(theta, x).squeeze(-1)

        # vectorize init and forward functions
        self.eltwise_nn_init_random_params = vmap(self.nn_init_random_params, (0, None), 0)
        self.double_eltwise_nn_init_random_params = vmap(self.eltwise_nn_init_random_params, (0, None), 0)
        self.triple_eltwise_nn_init_random_params = vmap(self.double_eltwise_nn_init_random_params, (0, None), 0)

        # [d2, ?], [N, d] -> [N, d2]
        self.eltwise_nn_forward = vmap(self.nn_forward, (0, None), 1)

        # [d2, ?], [d2, N, d] -> [N, d2]
        self.double_eltwise_nn_forward = vmap(self.nn_forward, (0, 0), 1)
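
    # Structural equation implied by the model (cf. class docstring and the paper above):
    #   x_j = f_j(x_{pa(j)}; theta_j) + eps_j,   eps_j ~ N(0, obs_noise)
    # where f_j is the MLP of node j and pa(j) are j's parents in the graph.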

    def get_theta_shape(self, *, n_vars):
        """Returns tree shape of the parameters of the neural networks

        Args:
            n_vars (int): number of variables in model

        Returns:
            PyTree of parameter shape
        """

        dummy_subkeys = jnp.zeros((n_vars, 2), dtype=jnp.uint32)
        # second arg is `input_shape` of the NN forward pass
        _, theta = self.eltwise_nn_init_random_params(dummy_subkeys, (n_vars, ))

        theta_shape = tree_shapes(theta)
        return theta_shape

    def sample_parameters(self, *, key, n_vars, n_particles=0, batch_size=0):
        """Samples batch of random parameters given dimensions of graph from :math:`p(\\Theta | G)`

        Args:
            key (ndarray): rng
            n_vars (int): number of variables in BN
            n_particles (int): number of parameter particles sampled
            batch_size (int): number of batches of particles being sampled

        Returns:
            Parameter PyTree with leading dimension(s) ``batch_size`` and/or ``n_particles``,
            dropping either dimension when equal to 0
        """
        shape = [d for d in (batch_size, n_particles, n_vars) if d != 0]
        subkeys = random.split(key, int(onp.prod(shape))).reshape(*shape, 2)

        if len(shape) == 1:
            _, theta = self.eltwise_nn_init_random_params(subkeys, (n_vars, ))
        elif len(shape) == 2:
            _, theta = self.double_eltwise_nn_init_random_params(subkeys, (n_vars, ))
        elif len(shape) == 3:
            _, theta = self.triple_eltwise_nn_init_random_params(subkeys, (n_vars, ))
        else:
            raise ValueError(f"invalid shape size for nn param initialization {shape}")

        # cast to float64 if 64-bit precision is enabled
        prec64 = 'JAX_ENABLE_X64' in os.environ and os.environ['JAX_ENABLE_X64'] == 'True'
        theta = tree_map(lambda arr: arr.astype(jnp.float64 if prec64 else jnp.float32), theta)

        return theta
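
    # Illustration (not part of the library): the leading dimensions of the returned
    # PyTree follow the non-zero arguments. With the hypothetical call below, every
    # parameter leaf has leading shape [batch_size, n_particles, n_vars] = [2, 10, 5].
    #
    #   theta = model.sample_parameters(key=random.PRNGKey(0), n_vars=5,
    #                                   n_particles=10, batch_size=2)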

    def sample_obs(self, *, key, n_samples, g, theta, toporder=None, interv=None):
        """Samples ``n_samples`` observations given graph ``g`` and parameters ``theta``
        by doing single forward passes in topological order

        Args:
            key (ndarray): rng
            n_samples (int): number of samples
            g (igraph.Graph): graph
            theta (Any): parameters
            interv (dict): intervention specification of the form ``{intervened node : clamp value}``

        Returns:
            observation matrix of shape ``[n_samples, n_vars]``
        """
        if interv is None:
            interv = {}
        if toporder is None:
            toporder = g.topological_sorting()

        n_vars = len(g.vs)
        x = jnp.zeros((n_samples, n_vars))

        key, subk = random.split(key)
        z = jnp.sqrt(self.obs_noise) * random.normal(subk, shape=(n_samples, n_vars))

        g_mat = graph_to_mat(g)

        # ancestral sampling
        # for simplicity, this does d full forward passes, which avoids indexing into the python list of parameters
        for j in toporder:

            # intervention
            if j in interv.keys():
                x = x.at[:, j].set(interv[j])
                continue

            # regular ancestral sampling
            parents = g_mat[:, j].reshape(1, -1)
            has_parents = parents.sum() > 0

            if has_parents:
                # [N, d] = [N, d] * [1, d] mask non-parent entries of j
                x_msk = x * parents

                # [N, d] full forward pass
                means = self.eltwise_nn_forward(theta, x_msk)

                # [N,] update j only
                x = x.at[:, j].set(means[:, j] + z[:, j])
            else:
                x = x.at[:, j].set(z[:, j])

        return x
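
    # Illustration (not part of the library): clamping node 0 to the value 2.0 during
    # ancestral sampling; `g` and `theta` are assumed to come from an igraph DAG and
    # `sample_parameters`, respectively.
    #
    #   x_interv = model.sample_obs(key=subk, n_samples=100, g=g, theta=theta,
    #                               interv={0: 2.0})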
""" The following functions need to be functionally pure and @jit-able """

    def log_prob_parameters(self, *, theta, g):
        """Computes parameter prior :math:`\\log p(\\Theta | G)`
        In this model, the prior over weights and biases is zero-centered Gaussian.

        Arguments:
            theta (Any): parameter pytree
            g (ndarray): graph adjacency matrix of shape ``[n_vars, n_vars]``

        Returns:
            log prob
        """
        # compute log prob for each weight
        logprobs = tree_map(lambda leaf_theta: jax_normal.logpdf(x=leaf_theta, loc=0.0, scale=self.sig_param), theta)

        # mask logprobs of first layer weight matrix [0][0] according to graph
        # [d, d, dim_first_layer] = [d, d, dim_first_layer] * [d, d, 1]
        if self.bias:
            first_weight_logprobs, first_bias_logprobs = logprobs[0]
            logprobs[0] = (first_weight_logprobs * g.T[:, :, None], first_bias_logprobs)
        else:
            first_weight_logprobs, = logprobs[0]
            logprobs[0] = (first_weight_logprobs * g.T[:, :, None],)

        # sum logprobs of every parameter tensor and add all up
        return tree_reduce(jnp.add, tree_map(jnp.sum, logprobs))
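
    # Illustration (not part of the library): because the first-layer weight log-probs
    # are multiplied by `g.T`, weights attached to non-parents of a node do not
    # contribute to the prior score. A hypothetical call with an igraph DAG `g`:
    #
    #   lp = model.log_prob_parameters(theta=theta, g=graph_to_mat(g))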

    def log_likelihood(self, *, x, theta, g, interv_targets):
        """Computes likelihood :math:`p(D | G, \\Theta)`.
        In this model, the noise per observation and node is additive and Gaussian.

        Arguments:
            x (ndarray): observations of shape ``[n_observations, n_vars]``
            theta (Any): parameters PyTree
            g (ndarray): graph adjacency matrix of shape ``[n_vars, n_vars]``
            interv_targets (ndarray): binary intervention indicator mask of shape ``[n_observations, n_vars]``

        Returns:
            log prob
        """
        assert x.shape == interv_targets.shape

        # [d2, N, d] = [1, N, d] * [d2, 1, d] mask non-parent entries of each j
        all_x_msk = x[None] * g.T[:, None]

        # [N, d2] NN forward passes for parameters of each node j
        all_means = self.double_eltwise_nn_forward(theta, all_x_msk)

        # sum scores for all nodes and data
        return jnp.sum(
            jnp.where(
                # [n_observations, n_vars]
                interv_targets,
                0.0,
                # [n_observations, n_vars]
                jax_normal.logpdf(x=x, loc=all_means, scale=jnp.sqrt(self.obs_noise))
            )
        )
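
    # Illustration (not part of the library): scoring purely observational data by
    # passing an all-False intervention mask of the same shape as a hypothetical
    # observation matrix `x`:
    #
    #   ll = model.log_likelihood(x=x, theta=theta, g=graph_to_mat(g),
    #                             interv_targets=jnp.zeros_like(x, dtype=bool))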
""" Distributions used by DiBS for inference: prior and joint likelihood """

    def interventional_log_joint_prob(self, g, theta, x, interv_targets, rng):
        """Computes interventional joint likelihood :math:`\\log p(\\Theta, D | G)`

        Arguments:
            g (ndarray): graph adjacency matrix of shape ``[n_vars, n_vars]``
            theta (Any): parameter PyTree
            x (ndarray): observational data of shape ``[n_observations, n_vars]``
            interv_targets (ndarray): indicator mask of interventions of shape ``[n_observations, n_vars]``
            rng (ndarray): rng; skeleton for minibatching (TBD)

        Returns:
            log prob of shape ``[1,]``
        """
        log_prob_theta = self.log_prob_parameters(g=g, theta=theta)
        log_likelihood = self.log_likelihood(g=g, theta=theta, x=x, interv_targets=interv_targets)
        return log_prob_theta + log_likelihood
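

if __name__ == '__main__':
    # Minimal end-to-end sketch (illustration only, not part of the library):
    # build a 3-node chain DAG with igraph, sample parameters and observations,
    # and evaluate the joint density. All hyperparameters are arbitrary assumptions.
    import igraph as ig

    model = DenseNonlinearGaussian(n_vars=3, hidden_layers=(5,), obs_noise=0.1,
                                   sig_param=1.0, activation='relu')

    g = ig.Graph(directed=True)
    g.add_vertices(3)
    g.add_edges([(0, 1), (1, 2)])  # chain 0 -> 1 -> 2
    g_mat = graph_to_mat(g)

    key = random.PRNGKey(0)
    key, subk = random.split(key)
    theta = model.sample_parameters(key=subk, n_vars=3)

    key, subk = random.split(key)
    x = model.sample_obs(key=subk, n_samples=100, g=g, theta=theta)

    interv_targets = jnp.zeros_like(x, dtype=bool)  # purely observational data
    log_joint = model.interventional_log_joint_prob(g_mat, theta, x, interv_targets, None)
    print(x.shape, log_joint)  # (100, 3) and a scalar log density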