import igraph as ig
import random as pyrandom
import jax.numpy as jnp
from jax import random
from dibs.graph_utils import mat_to_graph, graph_to_mat, mat_is_dag
from dibs.utils.func import zero_diagonal
class ErdosReniDAGDistribution:
    """
    Randomly oriented Erdos-Reni (Erdős–Rényi) random graph model with
    i.i.d. edge probability. The pmf is defined as

    :math:`p(G) \\propto p^e (1-p)^{\\binom{d}{2} - e}`

    where :math:`e` denotes the total number of edges in :math:`G`,
    and :math:`p` is chosen to satisfy the requirement of sampling
    ``n_edges_per_node`` edges per node in expectation.

    Args:
        n_vars (int): number of variables in DAG
        n_edges_per_node (int): number of edges sampled per variable in expectation
    """

    def __init__(self, n_vars, n_edges_per_node=2):
        self.n_vars = n_vars
        self.n_edges = n_edges_per_node * n_vars
        # Bernoulli probability per unordered node pair so that the expected
        # edge count equals self.n_edges
        self.p = self.n_edges / ((self.n_vars * (self.n_vars - 1)) / 2)

    def sample_G(self, key, return_mat=False):
        """Samples a DAG from the model.

        Args:
            key (ndarray): rng key
            return_mat (bool): if ``True``, returns adjacency matrix of
                shape ``[n_vars, n_vars]`` instead of an igraph object

        Returns:
            ``iGraph.graph`` / ``jnp.array``: DAG
        """
        key, subk = random.split(key)
        edge_draws = random.bernoulli(
            subk, p=self.p, shape=(self.n_vars, self.n_vars)).astype(jnp.int32)

        # keeping only the strictly-lower-triangular part (k=-1 zeroes the
        # diagonal too) guarantees acyclicity under the natural node ordering
        lower = jnp.tril(edge_draws, k=-1)

        # apply a random node relabeling so edge orientation is not tied to
        # the index order
        key, subk = random.split(key)
        perm_mat = random.permutation(subk, jnp.eye(self.n_vars, dtype=jnp.int32))
        permuted = perm_mat.T @ lower @ perm_mat

        if return_mat:
            return permuted
        return mat_to_graph(permuted)

    def unnormalized_log_prob_single(self, *, g, j):
        """
        Computes :math:`\\log p(G_j)` up to the normalization constant.

        Args:
            g (iGraph.graph): graph
            j (int): node index

        Returns:
            unnormalized log probability of the node family of :math:`j`
        """
        n_parents = len(g.incident(j, mode='in'))
        # each of the other (n_vars - 1) nodes is either a parent or not
        n_absent = self.n_vars - n_parents - 1
        return n_parents * jnp.log(self.p) + n_absent * jnp.log(1 - self.p)

    def unnormalized_log_prob(self, *, g):
        """
        Computes :math:`\\log p(G)` up to the normalization constant.

        Args:
            g (iGraph.graph): graph

        Returns:
            unnormalized log probability of :math:`G`
        """
        max_edges = self.n_vars * (self.n_vars - 1) / 2.0
        n_edges = len(g.es)
        return n_edges * jnp.log(self.p) + (max_edges - n_edges) * jnp.log(1 - self.p)

    def unnormalized_log_prob_soft(self, *, soft_g):
        """
        Computes :math:`\\log p(G)` up to the normalization constant,
        where :math:`G` is the matrix of edge probabilities.

        Args:
            soft_g (ndarray): graph adjacency matrix, where entries
                may be probabilities and not necessarily 0 or 1

        Returns:
            unnormalized log probability corresponding to edge probabilities in :math:`G`
        """
        max_edges = self.n_vars * (self.n_vars - 1) / 2.0
        # total (possibly fractional) edge mass plays the role of the edge count
        total_mass = soft_g.sum()
        return total_mass * jnp.log(self.p) + (max_edges - total_mass) * jnp.log(1 - self.p)
class ScaleFreeDAGDistribution:
    """
    Randomly-oriented scale-free random graph with power-law degree distribution.
    The pmf is defined as

    :math:`p(G) \\propto \\prod_j (1 + \\text{deg}(j))^{-3}`

    where :math:`\\text{deg}(j)` denotes the in-degree of node :math:`j`.

    Args:
        n_vars (int): number of variables in DAG
        n_edges_per_node (int): number of edges sampled per variable
    """

    def __init__(self, n_vars, verbose=False, n_edges_per_node=2):
        self.n_vars = n_vars
        self.n_edges_per_node = n_edges_per_node
        self.verbose = verbose

    def sample_G(self, key, return_mat=False):
        """Samples a DAG from the model.

        Args:
            key (ndarray): rng key
            return_mat (bool): if ``True``, returns adjacency matrix of
                shape ``[n_vars, n_vars]`` instead of an igraph object

        Returns:
            ``iGraph.graph`` / ``jnp.array``: DAG
        """
        # derive a deterministic seed for igraph's generator from the jax key
        pyrandom.seed(int(key.sum()))
        perm = random.permutation(key, self.n_vars).tolist()
        g = ig.Graph.Barabasi(
            n=self.n_vars, m=self.n_edges_per_node, directed=True
        ).permute_vertices(perm)

        if return_mat:
            return graph_to_mat(g)
        return g

    def unnormalized_log_prob_single(self, *, g, j):
        """
        Computes :math:`\\log p(G_j)` up to the normalization constant.

        Args:
            g (iGraph.graph): graph
            j (int): node index

        Returns:
            unnormalized log probability of the node family of :math:`j`
        """
        in_degree = len(g.incident(j, mode='in'))
        return -3 * jnp.log(1 + in_degree)

    def unnormalized_log_prob(self, *, g):
        """
        Computes :math:`\\log p(G)` up to the normalization constant.

        Args:
            g (iGraph.graph): graph

        Returns:
            unnormalized log probability of :math:`G`
        """
        # product over node families becomes a sum of per-node log terms
        scores = [self.unnormalized_log_prob_single(g=g, j=node)
                  for node in range(self.n_vars)]
        return jnp.array(scores).sum()

    def unnormalized_log_prob_soft(self, *, soft_g):
        """
        Computes :math:`\\log p(G)` up to the normalization constant,
        where :math:`G` is the matrix of edge probabilities.

        Args:
            soft_g (ndarray): graph adjacency matrix, where entries
                may be probabilities and not necessarily 0 or 1

        Returns:
            unnormalized log probability corresponding to edge probabilities in :math:`G`
        """
        # column sums are the (soft) in-degrees
        in_degrees = soft_g.sum(0)
        return (-3 * jnp.log(1 + in_degrees)).sum()
class UniformDAGDistributionRejection:
    """
    Naive implementation of a uniform distribution over DAGs via rejection
    sampling. This is efficient up to roughly :math:`d = 5`.
    Properly sampling a uniformly-random DAG is possible but nontrivial
    and not implemented here.

    Args:
        n_vars (int): number of variables in DAG
    """

    def __init__(self, n_vars):
        self.n_vars = n_vars

    def sample_G(self, key, return_mat=False):
        """Samples a DAG by rejection: draw random adjacency matrices
        until one happens to be acyclic.

        Args:
            key (ndarray): rng key
            return_mat (bool): if ``True``, returns adjacency matrix of
                shape ``[n_vars, n_vars]`` instead of an igraph object

        Returns:
            ``iGraph.graph`` / ``jnp.array``: DAG
        """
        while True:
            key, subk = random.split(key)
            cand = random.bernoulli(
                subk, p=0.5, shape=(self.n_vars, self.n_vars)).astype(jnp.int32)
            cand = zero_diagonal(cand)
            # reject cyclic candidates and redraw
            if not mat_is_dag(cand):
                continue
            return cand if return_mat else mat_to_graph(cand)

    def unnormalized_log_prob_single(self, *, g, j):
        """
        Computes :math:`\\log p(G_j)` up to the normalization constant.
        Constant under the uniform distribution.

        Args:
            g (iGraph.graph): graph
            j (int): node index

        Returns:
            unnormalized log probability of the node family of :math:`j`
        """
        return jnp.array(0.0)

    def unnormalized_log_prob(self, *, g):
        """
        Computes :math:`\\log p(G)` up to the normalization constant.
        Constant under the uniform distribution.

        Args:
            g (iGraph.graph): graph

        Returns:
            unnormalized log probability of :math:`G`
        """
        return jnp.array(0.0)

    def unnormalized_log_prob_soft(self, *, soft_g):
        """
        Computes :math:`\\log p(G)` up to the normalization constant,
        where :math:`G` is the matrix of edge probabilities.
        Constant under the uniform distribution.

        Args:
            soft_g (ndarray): graph adjacency matrix, where entries
                may be probabilities and not necessarily 0 or 1

        Returns:
            unnormalized log probability corresponding to edge probabilities in :math:`G`
        """
        return jnp.array(0.0)