Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
probml
GitHub Repository: probml/pyprobml
Path: blob/master/deprecated/scripts/bayesnet_inf_autodiff_test.py
1192 views
1
2
import superimport
3
4
import numpy as np # original numpy
5
import jax.numpy as jnp
6
from jax import grad
7
import bayesnet_inf_autodiff as bn
8
9
# Network from Fig. 3 of Darwiche (2003): root A with two children, B and C.
# States are encoded as 0 = False, 1 = True, so the CPT entries are ordered
# differently from the paper's tables.

# Prior over the root: thetaA[a] = P(A=a).
thetaA = jnp.array([0.5, 0.5])

# Conditional tables are indexed child-state first: thetaX[x, a] = P(X=x | A=a).
thetaB = jnp.array([[1.0, 0.0],
                    [0.0, 1.0]])  # identity: B always takes the same state as A
thetaC = jnp.array([[0.8, 0.2],
                    [0.2, 0.8]])

# Parent list for every node of the DAG.
dag = {'A': [], 'B': ['A'], 'C': ['A']}

params = {'A': thetaA, 'B': thetaB, 'C': thetaC}

# Number of states per node, read off the leading (child) axis of each CPT.
cardinality = {node: table.shape[0] for node, table in params.items()}
20
21
# Expected einsum contraction for this DAG. Presumably the first three
# operands (A, B, C) are the evidence vectors and the last three (A, BA, CA)
# are the CPTs thetaA[a], thetaB[b,a], thetaC[c,a]; the empty output sums
# everything down to a scalar.
assert bn.make_einsum_string(dag) == 'A,B,C,A,BA,CA->'

# Observe A=1 (True) and C=0 (False); B is left unobserved.
evidence = {'A': 1, 'C': 0}

evectors = bn.make_evidence_vectors(cardinality, evidence)
fe = bn.network_poly(dag, params, evectors)  # probability of evidence
# fe is produced by floating-point arithmetic, so compare with a tolerance
# rather than exact equality.
assert np.isclose(fe, 0.1)
29
30
# Partial derivatives of the network polynomial with respect to each
# evidence vector; the expected numbers match Table 1 of Darwiche (2003).
def f(ev):
    return bn.network_poly(dag, params, ev)


grads = grad(f)(evectors)  # dict of derivatives, keyed like evectors

# Compare with NumPy: jax.numpy functions reject plain Python lists as
# array arguments (TypeError), whereas np.allclose accepts both the JAX
# result and a list of expected values.
assert np.allclose(grads['A'], [0.4, 0.1])  # A
assert np.allclose(grads['B'], [0.0, 0.1])  # B
assert np.allclose(grads['C'], [0.1, 0.4])  # C
36
37
# Marginals under the same evidence; prob_ev should again equal P(evidence).
prob_ev, probs = bn.marginal_probs(dag, params, evidence)
# Tolerance-based comparisons: the values are floating-point, and
# jax.numpy functions reject Python lists as array arguments.
assert np.isclose(prob_ev, 0.1)
# thetaB is the identity (B copies A), so observing A=1 forces B=1.
assert np.allclose(probs['B'], [0.0, 1.0])

print('tests passed')
42