# -*- coding: utf-8 -*-
"""bayesnet_inf_autodiff.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1jlQnVtVH4-eqR6L8ny4iNG5_7C32wCUj
"""

'''
Implements inference in a Bayes net using autodiff applied to Z=einsum(factors).
[email protected], April 2019

Based on "A differential approach to inference in Bayesian networks",
Adnan Darwiche, JACM 2003.
Cached copy:
https://github.com/probml/pyprobml/blob/master/data/darwiche-acm2003-differential.pdf
A similar result is shown in
"Inside-outside and forward-backward algorithms are just backprop",
Jason Eisner (2016),
EMNLP Workshop on Structured Prediction for NLP.
http://cs.jhu.edu/~jason/papers/eisner.spnlp16.pdf

The idea of using einsum instead of an arithmetic circuit to compute Z
is based on
"Tensor Variable Elimination for Plated Factor Graphs",
Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin Chiu, Neeraj Pradhan,
Alexander Rush, Noah Goodman,
arXiv 2019 (https://arxiv.org/abs/1902.03210).

For a demo of how to use this code, see
https://github.com/probml/pyprobml/blob/master/book/student_pgm_inf_autodiff.py
For a unit test, see
https://github.com/probml/pyprobml/blob/master/book/bayesnet_inf_autodiff_test.py

Darwiche defines the network polynomial f(l) = poly(l, theta),
where l(i,j)=1 if variable i is in state j; these
are called the evidence vectors, denoted by lambda.
Let e be a vector of observations for a (sub)set of the nodes.
Let o(i)=1 if variable i is observed in e, and o(i)=0 otherwise.
So l(i,j)=ind{j=e(i)} if o(i)=1, and l(i,:)=1 otherwise.
Thus l(i,j)=1 means the setting x(i)=j is compatible with e.
Define f(e) = f(l(e)), where l(e) denotes this binding process.

Thm 1: f(l(e)) = Pr(x(o)=e) = Pr(e), the probability of the evidence.
f(e) is also denoted by Z, the normalization constant in Bayes' rule.
Note that we can compute f(e) using Einstein summation over all terms
in the network polynomial.

Thm 2: let g_{ij}(l) = d/dl(i,j) f(l(e)) be the partial derivative,
so g_i(l) is the gradient vector for variable i.
Then g_{ij}(l) = Pr(x(i)=j, x(o(-i))=e(-i)), where o(-i) are all observed
variables except i.

Corollary 1: d/dl(i,j) log f(l(e)) = 1/f(l(e)) * g_{ij}(l(e))
= Pr(x(i)=j | e) if o(i)=0 (so i is not in e).
This is the standard result that derivatives of the log partition function
give the expected sufficient statistics, which for a multinomial
are the posterior marginals over states.

We use jax (https://github.com/google/jax) to compute the partial
derivatives. This requires that f(e) be implemented using jax's
version of einsum, which fortunately is 100% compatible with the numpy
version.
'''

import superimport

import numpy as np  # original numpy
import jax.numpy as jnp
from jax import grad
# NB: jax.ops.index_update was removed from JAX; the .at[...].set(...)
# syntax is used below instead.

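# A worked check of Thm 1 (a minimal sketch with made-up numbers, not part of
# the library API): for the chain A -> B, contracting the network polynomial
# with einsum yields Pr(e). Here e = {B=1}, so Z should equal
# sum_a P(a) P(B=1|a). The operand order (evidence vectors, then CPTs)
# matches the convention of make_einsum_string below.
def _demo_network_poly_thm1():
    lam_A = jnp.ones(2)              # A is unobserved: l(A,:) = 1
    lam_B = jnp.array([0.0, 1.0])    # evidence B=1: l(B,j) = ind{j=1}
    prior_A = jnp.array([0.6, 0.4])  # P(A)
    cpt_B = jnp.array([[0.7, 0.2],
                       [0.3, 0.8]])  # cpt_B[b, a] = P(B=b | A=a)
    Z = jnp.einsum('A,B,A,BA->', lam_A, lam_B, prior_A, cpt_B)
    assert jnp.allclose(Z, prior_A @ cpt_B[1, :])  # Pr(B=1) = 0.5
    return Z
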
def make_einsum_string(dag):
    # Example: if dag is B <- A -> C, returns 'A,B,C,A,BA,CA->'
    # (evidence-vector subscripts first, then CPT subscripts).
    node_names = list(dag.keys())
    cpt_names = [n + ''.join(dag[n]) for n in node_names]  # indices for CPTs
    subscripts = ','.join(node_names) + ',' + ','.join(cpt_names) + '->'
    return subscripts

def make_list_of_factors(dag, params, evectors):
    # Extract dictionary elements in the same order as the einsum string.
    node_names = list(dag.keys())
    evecs = []
    cpts = []
    for n in node_names:
        evecs.append(evectors[n])
        cpts.append(params[n])
    return evecs + cpts

def network_poly(dag, params, evectors, elim_order=None):
    # Sum over all assignments to the network polynomial to compute Z.
    subscripts = make_einsum_string(dag)
    factors = make_list_of_factors(dag, params, evectors)
    if elim_order is None:
        return jnp.einsum(subscripts, *factors)
    else:
        return jnp.einsum(subscripts, *factors, optimize=elim_order)

def make_evidence_vectors(cardinality, evidence):
    # Compute l(i,j) = 1 iff x(i)=j is compatible with the evidence e.
    def f(nstates, val):
        if val == -1:  # unobserved: all-ones vector
            vec = jnp.ones(nstates)
        else:
            # vec[val] = 1.0  # not allowed to mutate state in jax
            vec = jnp.zeros(nstates).at[val].set(1.0)  # functional assignment
        return vec
    return {name: f(nstates, evidence.get(name, -1)) for name, nstates in cardinality.items()}

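# Example (illustrative values): with cardinality {'A': 2, 'B': 2} and
# evidence {'B': 1}, make_evidence_vectors returns
# {'A': [1., 1.], 'B': [0., 1.]}, i.e. unobserved nodes get all-ones
# lambda vectors and observed nodes get one-hot indicators.
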
def marginal_probs(dag, params, evidence, elim_order=None):
    # Compute the marginal probabilities of all nodes in a Bayes net.
    cardinality = {name: jnp.shape(CPT)[0] for name, CPT in params.items()}
    evectors = make_evidence_vectors(cardinality, evidence)
    f = lambda ev: network_poly(dag, params, ev, elim_order)  # clamp model parameters
    prob_ev = f(evectors)
    grads = grad(f)(evectors)  # dict of derivatives wrt the evidence vectors
    probs = dict()
    for name in dag.keys():
        ev = evidence.get(name, -1)
        if ev == -1:  # not observed: posterior marginal by Corollary 1
            probs[name] = grads[name] / prob_ev
        else:
            probs[name] = evectors[name]  # clamped node
        probs[name] = np.array(probs[name])  # cast back to vanilla numpy array
    return prob_ev, probs

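# A minimal usage sketch (toy numbers, assuming the chain A -> B):
#   dag = {'A': [], 'B': ['A']}
#   params = {'A': jnp.array([0.6, 0.4]),
#             'B': jnp.array([[0.7, 0.2], [0.3, 0.8]])}  # P(B=b | A=a) in [b, a]
#   prob_ev, probs = marginal_probs(dag, params, {'B': 1})
#   # prob_ev = Pr(B=1) = 0.5; probs['A'] = Pr(A | B=1) = [0.36, 0.64]
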
def compute_elim_order(dag, params):
    # Compute a good elimination order, assuming no nodes are observed.
    evidence = {}
    cardinality = {name: jnp.shape(CPT)[0] for name, CPT in params.items()}
    evectors = make_evidence_vectors(cardinality, evidence)
    subscripts = make_einsum_string(dag)
    factors = make_list_of_factors(dag, params, evectors)
    nnodes = len(dag.keys())
    #print('computing elimination order for DAG with {} nodes'.format(nnodes))
    # 'optimal' is exact but scales exponentially in the number of factors;
    # 'greedy' is a cheap heuristic.
    #elim_order = jnp.einsum_path(subscripts, *factors, optimize='optimal')[0]
    elim_order = jnp.einsum_path(subscripts, *factors, optimize='greedy')[0]
    return elim_order

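# Note: einsum_path returns a path list whose first element is the literal
# string 'einsum_path', followed by pairwise contraction indices, e.g.
# ['einsum_path', (0, 2), (0, 2), (0, 1)]; passing this list back via
# optimize= reuses the precomputed contraction order on every call.
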
# Class that provides syntactic sugar on top of the functions above.
class BayesNetInfAutoDiff:
    def __init__(self, dag, params):
        self._dag = dag
        self._params = params
        self._elim_order = compute_elim_order(dag, params)
        #self._elim_order = None

    def infer_marginals(self, evidence):
        prob_ev, marginals = marginal_probs(self._dag, self._params, evidence,
                                            self._elim_order)
        return marginals
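
# A minimal self-contained smoke test (toy numbers; the "student" demo linked
# in the module docstring is the fuller example). Node names must be single
# characters, since they are used directly as einsum subscripts.
if __name__ == '__main__':
    dag = {'A': [], 'B': ['A']}
    params = {'A': jnp.array([0.6, 0.4]),    # P(A)
              'B': jnp.array([[0.7, 0.2],
                              [0.3, 0.8]])}  # params['B'][b, a] = P(B=b | A=a)
    model = BayesNetInfAutoDiff(dag, params)
    marginals = model.infer_marginals({'B': 1})
    print(marginals)  # expect {'A': [0.36, 0.64], 'B': [0., 1.]}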