Path: blob/master/deprecated/scripts/bayesnet_inf_autodiff.py
# -*- coding: utf-8 -*-
"""bayesnet_inf_autodiff.ipynb

Automatically generated by Colaboratory.

Original file is located at
https://colab.research.google.com/drive/1jlQnVtVH4-eqR6L8ny4iNG5_7C32wCUj
"""

'''
Implements inference in a Bayes net using autodiff applied to Z=einsum(factors).
[email protected], April 2019
Based on "A differential approach to inference in Bayesian networks",
Adnan Darwiche, JACM 2003.
Cached copy:
https://github.com/probml/pyprobml/blob/master/data/darwiche-acm2003-differential.pdf
A similar result is shown in
"Inside-outside and forward-backward algorithms are just backprop",
Jason Eisner (2016),
EMNLP Workshop on Structured Prediction for NLP.
http://cs.jhu.edu/~jason/papers/eisner.spnlp16.pdf

The idea of using einsum instead of an arithmetic circuit to compute Z
is based on
"Tensor Variable Elimination for Plated Factor Graphs",
Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin Chiu, Neeraj Pradhan,
Alexander Rush, Noah Goodman,
Arxiv 2019 (https://arxiv.org/abs/1902.03210).

For a demo of how to use this code, see
https://github.com/probml/pyprobml/blob/master/book/student_pgm_inf_autodiff.py
For a unit test, see
https://github.com/probml/pyprobml/blob/master/book/bayesnet_inf_autodiff_test.py

Darwiche defines the network polynomial f(l) = poly(l, theta),
where l(i,j)=1 if variable i is in state j; these
are called the evidence vectors, denoted by lambda.
Let e be a vector of observations for a (sub)set of the nodes.
Let o(i)=1 if variable i is observed in e, and o(i)=0 otherwise.
So l(i,j)=ind{j=e(i)} if o(i)=1, and l(i,:)=1 otherwise.
Thus l(i,j)=1 means the setting x(i)=j is compatible with e.
Define f(e) = f(l(e)), where l(e) is this binding process.
Thm 1: f(l(e)) = Pr(x(o)=e) = Pr(e), the probability of the evidence.
f(e) is also denoted by Z, the normalization constant in Bayes' rule.
Note that we can compute f(e) using Einstein summation over all terms
in the network polynomial.
Thm 2: let g_{i,j}(l) = d/dl(i,j) f(l(e)) be the partial derivative,
so g_i(l) is the gradient vector for variable i.
Then g_{i,j}(l) = Pr(x(i)=j, x(o(-i))=e(-i)), where o(-i) are all observed
variables except i.
Corollary 1: d/dl(i,j) log f(l(e)) = 1/f(l(e)) * g_{i,j}(l(e))
= Pr(x(i)=j | e) if o(i)=0 (so i is not in e).
This is the standard result that the derivatives of the log partition function
give the expected sufficient statistics, which for a multinomial
are the posterior marginals over states.
We use JAX (https://github.com/google/jax) to compute the partial
derivatives. This requires that f(e) be implemented using JAX's
version of einsum, which fortunately is 100% compatible with the numpy
version.
'''

import superimport

import numpy as np  # original numpy
import jax.numpy as jnp
from jax import grad, jit, vmap
from functools import partial


def make_einsum_string(dag):
    # Example: if dag is B <- A -> C, returns 'A,B,C,A,BA,CA->'
    # (evidence vectors A, B, C followed by CPTs A, BA, CA).
    node_names = list(dag.keys())
    cpt_names = [n + ''.join(dag[n]) for n in node_names]  # indices for CPTs
    einsum_str = ','.join(node_names) + ',' + ','.join(cpt_names) + '->'
    return einsum_str


def make_list_of_factors(dag, params, evectors):
    # Extract dictionary elements in the same order as the einsum string.
    node_names = list(dag.keys())
    evecs = []
    cpts = []
    for n in node_names:
        evecs.append(evectors[n])
        cpts.append(params[n])
    return evecs + cpts


def network_poly(dag, params, evectors, elim_order=None):
    # Sum over all assignments to the network polynomial to compute Z.
    einsum_str = make_einsum_string(dag)
    factors = make_list_of_factors(dag, params, evectors)
    if elim_order is None:
        return jnp.einsum(einsum_str, *factors)
    else:
        return jnp.einsum(einsum_str, *factors, optimize=elim_order)


def make_evidence_vectors(cardinality, evidence):
    # Compute l(i,j)=1 iff x(i)=j is compatible with evidence e.
    def f(nstates, val):
        if val == -1:  # unobserved node: all states are compatible
            vec = jnp.ones(nstates)
        else:
            # vec[val] = 1.0 is not allowed, since jax arrays are immutable;
            # use a functional (out-of-place) update instead. The original
            # jax.ops.index_update call has been removed from JAX, so we use
            # the .at[...].set(...) syntax.
            vec = jnp.zeros(nstates).at[val].set(1.0)
        return vec
    return {name: f(nstates, evidence.get(name, -1))
            for name, nstates in cardinality.items()}


def marginal_probs(dag, params, evidence, elim_order=None):
    # Compute the marginal probabilities of all nodes in a Bayes net.
    cardinality = {name: jnp.shape(CPT)[0] for name, CPT in params.items()}
    evectors = make_evidence_vectors(cardinality, evidence)
    f = lambda ev: network_poly(dag, params, ev, elim_order)  # clamp model parameters
    prob_ev = f(evectors)
    grads = grad(f)(evectors)  # dict of derivatives wrt the evidence vectors
    probs = dict()
    for name in dag.keys():
        ev = evidence.get(name, -1)
        if ev == -1:  # not observed
            probs[name] = grads[name] / prob_ev
        else:
            probs[name] = evectors[name]  # clamped node
        probs[name] = np.array(probs[name])  # cast back to vanilla numpy array
    return prob_ev, probs


def compute_elim_order(dag, params):
    # Compute an elimination (contraction) order, assuming no nodes are observed.
    evidence = {}
    cardinality = {name: jnp.shape(CPT)[0] for name, CPT in params.items()}
    evectors = make_evidence_vectors(cardinality, evidence)
    einsum_str = make_einsum_string(dag)
    factors = make_list_of_factors(dag, params, evectors)
    nnodes = len(dag.keys())
    #print('computing elimination order for DAG with {} nodes'.format(nnodes))
    #elim_order = jnp.einsum_path(einsum_str, *factors, optimize='optimal')[0]
    elim_order = jnp.einsum_path(einsum_str, *factors, optimize='greedy')[0]
    return elim_order


# Class that provides syntactic sugar on top of the above functions.
class BayesNetInfAutoDiff:
    def __init__(self, dag, params):
        self._dag = dag
        self._params = params
        self._elim_order = compute_elim_order(dag, params)
        #self._elim_order = None

    def infer_marginals(self, evidence):
        prob_ev, marginals = marginal_probs(self._dag, self._params, evidence,
                                            self._elim_order)
        return marginals
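

# The demo referenced in the docstring lives in a separate file; the block
# below is a minimal illustrative sketch (not part of the original script).
# It assumes single-letter node names ('R' -> 'W'), so the names can double
# as einsum subscripts, and CPTs indexed as params[node][node_state, parent_states].
if __name__ == '__main__':
    dag = {'R': [], 'W': ['R']}            # W has a single parent R
    params = {
        'R': jnp.array([0.8, 0.2]),        # P(R)
        'W': jnp.array([[0.9, 0.1],
                        [0.1, 0.9]]),      # params['W'][w, r] = P(W=w | R=r)
    }

    # Observe W=1 and query the posterior over R directly.
    prob_ev, probs = marginal_probs(dag, params, {'W': 1})
    print(prob_ev)     # P(W=1) = 0.8*0.1 + 0.2*0.9 = 0.26
    print(probs['R'])  # P(R | W=1) = [0.08, 0.18] / 0.26 ~= [0.31, 0.69]

    # The class wrapper precomputes an elimination order and reuses it.
    model = BayesNetInfAutoDiff(dag, params)
    print(model.infer_marginals({'W': 1})['R'])  # same posterior as above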