GitHub Repository: probml/pyprobml
Path: blob/master/notebooks/misc/asia_pgm.ipynb
Kernel: Python 3


!pip install pgmpy # https://github.com/pgmpy/pgmpy#installation
Collecting pgmpy
  Downloading pgmpy-0.1.18-py3-none-any.whl (1.9 MB)
Installing collected packages: pgmpy
Successfully installed pgmpy-0.1.18
# Fetching the network (the http URL redirects to https, so use https directly)
!wget https://www.bnlearn.com/bnrepository/asia/asia.bif.gz
!gzip -qd asia.bif.gz
!ls
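Outside Colab, the `!wget` and `!gzip` shell escapes may not be available. A minimal sketch using only the Python standard library (`urllib.request` and `gzip`) does the same job; the helper names `fetch_bif` and `decompress_gz` are hypothetical, and the URL is the one used in the cell above.

```python
import gzip
import shutil
import urllib.request
from pathlib import Path

ASIA_URL = "https://www.bnlearn.com/bnrepository/asia/asia.bif.gz"


def decompress_gz(gz_path, out_path):
    """Decompress a .gz file to out_path (like `gzip -d`, but keeps the archive)."""
    with gzip.open(gz_path, "rb") as src, open(out_path, "wb") as dst:
        shutil.copyfileobj(src, dst)
    return Path(out_path)


def fetch_bif(url=ASIA_URL, out_path="asia.bif"):
    """Hypothetical helper: download a gzipped BIF file, then decompress it."""
    gz_path = str(out_path) + ".gz"
    urllib.request.urlretrieve(url, gz_path)
    return decompress_gz(gz_path, out_path)
```

After `fetch_bif()` returns, `BIFReader("asia.bif")` can load the file exactly as in the next cell.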
Saving to: ‘asia.bif.gz’
2022-04-22 22:02:01 (36.8 MB/s) - ‘asia.bif.gz’ saved [310/310]
asia.bif  sample_data
import numpy as np
from pgmpy.readwrite import BIFReader
from pgmpy.inference import VariableElimination

# Load the Asia network from the BIF file
reader = BIFReader("asia.bif")
asia_model = reader.get_model()

# Graph structure (in a notebook these only display when evaluated as the last line of a cell)
asia_model.nodes()
asia_model.edges()
CPDs = asia_model.get_cpds()

# Doing exact inference using Variable Elimination
asia_infer = VariableElimination(asia_model)

# Computing the probability of bronc given smoke.
# Every variable has states (yes, no), so state index 0 = yes and 1 = no.
# Passing the integer 0 instead of the state name "yes" triggers the UserWarning below.
q = asia_infer.query(variables=["bronc"], evidence={"smoke": 0}).values
print("p(bronchitis | smoke=0)", q)

/usr/local/lib/python3.7/dist-packages/pgmpy/factors/discrete/DiscreteFactor.py:537: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers
p(bronchitis | smoke=0) [0.6 0.4]
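Because `bronc`'s only parent in the Asia network is `smoke`, this query simply reads off one row of the CPD p(bronc | smoke). The NumPy sketch below reproduces the result by brute-force enumeration rather than pgmpy's variable elimination: build the joint table p(smoke, bronc), keep the slice consistent with the evidence, and renormalize. Only the 0.6/0.4 row is confirmed by the output above; the prior p(smoke) = (0.5, 0.5) and the smoke=no row (0.3, 0.7) are assumptions recalled from the standard Asia network, and they cancel out of this particular answer.

```python
import numpy as np

# State index 0 = yes, 1 = no, matching the pgmpy query above.
# ASSUMED values (standard Asia network): the prior p(smoke) and the smoke=no row.
p_smoke = np.array([0.5, 0.5])
p_bronc_given_smoke = np.array([
    [0.6, 0.4],  # smoke = yes -> p(bronc = yes), p(bronc = no)
    [0.3, 0.7],  # smoke = no (assumed)
])

# Joint table p(smoke, bronc): rows index smoke, columns index bronc.
joint = p_smoke[:, None] * p_bronc_given_smoke


def condition_on_row(joint, row):
    """Keep the joint entries consistent with the evidence row and renormalize."""
    unnormalized = joint[row]
    return unnormalized / unnormalized.sum()


q = condition_on_row(joint, 0)  # evidence: smoke = yes
print("p(bronchitis | smoke=yes)", q)
```

Building the full joint scales exponentially with the number of variables; variable elimination reaches the same answer by summing variables out one at a time instead.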
"""
Sanity check via Bayes' rule (A = asia, T = tub; t = yes, f = no):
p(A=t | T=t) = p(A=t) p(T=t | A=t) / [p(A=t) p(T=t | A=t) + p(A=f) p(T=t | A=f)]
             = 0.01 * 0.05 / (0.01 * 0.05 + 0.99 * 0.01)
             = 0.0481
"""
# State index 0 = yes (true), 1 = no (false)
q = asia_infer.query(variables=["asia"], evidence={"tub": 0}).values
print("p(asia | tub=yes)", q)
assert np.isclose(q[0], 0.0481, atol=1e-4)

/usr/local/lib/python3.7/dist-packages/pgmpy/factors/discrete/DiscreteFactor.py:537: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers
p(asia | tub=yes) [0.04807692 0.95192308]
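The docstring's Bayes' rule calculation can also be checked by enumeration. Only p(asia) and p(tub | asia) from the docstring are needed: every other variable is unobserved and is not an ancestor of `asia` or `tub`, so it sums out to 1. This is a sketch of diagnostic (effect-to-cause) reasoning, not pgmpy's internal algorithm.

```python
import numpy as np

# State index 0 = yes, 1 = no. Values taken from the docstring above.
p_asia = np.array([0.01, 0.99])
p_tub_given_asia = np.array([
    [0.05, 0.95],  # asia = yes
    [0.01, 0.99],  # asia = no
])

# Joint p(asia, tub); the rest of the network sums out to 1 because
# no other variable is an ancestor of asia or tub.
joint = p_asia[:, None] * p_tub_given_asia

# Evidence tub = yes selects column 0; renormalizing over asia is Bayes' rule.
unnormalized = joint[:, 0]
posterior = unnormalized / unnormalized.sum()
print("p(asia | tub=yes)", posterior)
```

Unlike the `bronc` query, here the prior does not cancel: tuberculosis is five times more likely after a visit to Asia, yet the low base rate p(asia=yes) = 0.01 keeps the posterior below 5%.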