Path: blob/master/deprecated/scripts/bayesnet_inf_autodiff_test.py
1192 views
"""Tests for bayesnet_inf_autodiff on the example network of Darwiche'03.

The network has edges A -> B and A -> C over binary variables. We check the
einsum string built from the DAG, the network polynomial (probability of the
evidence), its gradients w.r.t. the evidence vectors, and the marginals
returned by ``marginal_probs``, against the numbers from fig 3 / table 1 of
the paper.
"""
import superimport

import numpy as np  # original numpy
import jax.numpy as jnp
from jax import grad
import bayesnet_inf_autodiff as bn

# Example from fig 3 of Darwiche'03 paper.
# Note that we assume 0=False, 1=True, so we order the entries differently
# from the paper.

thetaA = jnp.array([0.5, 0.5])  # thetaA[a] = P(A=a)
thetaB = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # thetaB[b,a] = P(B=b|A=a)
thetaC = jnp.array([[0.8, 0.2], [0.2, 0.8]])  # thetaC[c,a] = P(C=c|A=a)
params = {'A': thetaA, 'B': thetaB, 'C': thetaC}

# Number of states of each node = size of the first (child) axis of its CPT.
cardinality = {name: jnp.shape(cpt)[0] for name, cpt in params.items()}

# Parents of each node.
dag = {'A': [], 'B': ['A'], 'C': ['A']}

assert bn.make_einsum_string(dag) == 'A,B,C,A,BA,CA->'

# Observed values; B is left unobserved.
evidence = {'A': 1, 'C': 0}  # a=T, c=F

evectors = bn.make_evidence_vectors(cardinality, evidence)
fe = bn.network_poly(dag, params, evectors)  # probability of evidence
# These are floating-point results: compare with a tolerance rather than ==.
assert jnp.allclose(fe, 0.1)


# compare numbers to table 1 of Darwiche03
def f(ev):
    """Network polynomial viewed as a function of the evidence vectors only."""
    return bn.network_poly(dag, params, ev)


grads = grad(f)(evectors)  # derivatives wrt evectors, keyed by node name
assert jnp.allclose(grads['A'], [0.4, 0.1])  # A
assert jnp.allclose(grads['B'], [0.0, 0.1])  # B
assert jnp.allclose(grads['C'], [0.1, 0.4])  # C

prob_ev, probs = bn.marginal_probs(dag, params, evidence)
assert jnp.allclose(prob_ev, 0.1)
assert jnp.allclose(probs['B'], [0.0, 1.0])

print('tests passed')