Bayesnet example using Pomegranate
07 Apr 2016
Example taken from Coursera's Probabilistic Graphical Models, taught by Prof. Daphne Koller (https://class.coursera.org/pgm/)
Pomegranate library: https://github.com/jmschrei/pomegranate
In [1]:
import pomegranate as pom
In [2]:
difficulty = pom.DiscreteDistribution({ 'hard':0.6, 'easy':0.4 })
intelligence = pom.DiscreteDistribution({ 'intelligent':0.7, 'not-intelligent':0.3 })
grade = pom.ConditionalProbabilityTable(
[['hard', 'intelligent', 'A', 0.3 ],
['hard', 'intelligent', 'B', 0.4 ],
['hard', 'intelligent', 'C', 0.3 ],
['hard', 'not-intelligent', 'A', 0.05 ],
['hard', 'not-intelligent', 'B', 0.25 ],
['hard', 'not-intelligent', 'C', 0.7 ],
['easy', 'intelligent', 'A', 0.9 ],
['easy', 'intelligent', 'B', 0.08 ],
['easy', 'intelligent', 'C', 0.02 ],
['easy', 'not-intelligent', 'A', 0.5 ],
['easy', 'not-intelligent', 'B', 0.3 ],
['easy', 'not-intelligent', 'C', 0.2 ],
],
[difficulty, intelligence]
)
sat = pom.ConditionalProbabilityTable(
[['intelligent', 'high-mark', 0.95],
['intelligent', 'low-mark', 0.05],
['not-intelligent', 'high-mark', 0.2],
['not-intelligent', 'low-mark', 0.8]
],
[intelligence]
)
letter = pom.ConditionalProbabilityTable(
[['A', 'reference-letter', 0.9],
['B', 'reference-letter', 0.6],
['C', 'reference-letter', 0.01],
['A', 'no-reference-letter', 0.1],
['B', 'no-reference-letter', 0.4],
['C', 'no-reference-letter', 0.99]
],
[grade]
)
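Each `ConditionalProbabilityTable` row lists the parent value(s), then the child value, then its probability. For every combination of parent values, the child's probabilities should therefore sum to 1. The following is a minimal pure-Python sanity check of that property. It uses plain lists copied from the tables above, and the names `GRADE_ROWS`, `SAT_ROWS`, `LETTER_ROWS` and `check_cpt` are hypothetical helpers, not part of pomegranate:

```python
from collections import defaultdict

# Rows copied from the grade, sat and letter tables above:
# parent value(s) first, then the child value, then P(child | parents).
GRADE_ROWS = [
    ['hard', 'intelligent', 'A', 0.3], ['hard', 'intelligent', 'B', 0.4],
    ['hard', 'intelligent', 'C', 0.3],
    ['hard', 'not-intelligent', 'A', 0.05], ['hard', 'not-intelligent', 'B', 0.25],
    ['hard', 'not-intelligent', 'C', 0.7],
    ['easy', 'intelligent', 'A', 0.9], ['easy', 'intelligent', 'B', 0.08],
    ['easy', 'intelligent', 'C', 0.02],
    ['easy', 'not-intelligent', 'A', 0.5], ['easy', 'not-intelligent', 'B', 0.3],
    ['easy', 'not-intelligent', 'C', 0.2],
]
SAT_ROWS = [
    ['intelligent', 'high-mark', 0.95], ['intelligent', 'low-mark', 0.05],
    ['not-intelligent', 'high-mark', 0.2], ['not-intelligent', 'low-mark', 0.8],
]
LETTER_ROWS = [
    ['A', 'reference-letter', 0.9], ['B', 'reference-letter', 0.6],
    ['C', 'reference-letter', 0.01],
    ['A', 'no-reference-letter', 0.1], ['B', 'no-reference-letter', 0.4],
    ['C', 'no-reference-letter', 0.99],
]

def check_cpt(rows, tol=1e-9):
    """Hypothetical helper: total P(child | parents) per parent combination.

    Returns {parent_values: total} and raises ValueError if any total is not 1.
    """
    totals = defaultdict(float)
    for row in rows:
        parents, prob = tuple(row[:-2]), row[-1]  # drop child value and probability
        totals[parents] += prob
    for parents, total in totals.items():
        if abs(total - 1.0) > tol:
            raise ValueError("rows for %r sum to %r" % (parents, total))
    return dict(totals)

for name, rows in [('grade', GRADE_ROWS), ('sat', SAT_ROWS), ('letter', LETTER_ROWS)]:
    print(name, check_cpt(rows))
```

Grouping by the parent tuple means the rows can appear in any order, as in the letter table, where the two rows for grade `A` are not adjacent.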
In [3]:
d = pom.State(difficulty, name="difficulty")
i = pom.State(intelligence, name="intelligence")
g = pom.State(grade, name="grade")
s = pom.State(sat, name="sat")
l = pom.State(letter, name="letter")
In [4]:
network = pom.BayesianNetwork( "student" )
network.add_states( [ d, i, g, s, l ] )
In [5]:
network.add_transition( d, g )
network.add_transition( i, g )
network.add_transition( i, s )
network.add_transition( g, l )
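The four transitions say that each variable depends only on its parents, so the joint distribution factorizes as P(D, I, G, S, L) = P(D) P(I) P(G | D, I) P(S | I) P(L | G). The following is a minimal pure-Python sketch of that product. The CPTs are re-entered as nested dicts, and the names `P_D`, `P_I`, `P_G`, `P_S`, `P_L` and `joint` are illustrative only, not pomegranate API:

```python
from itertools import product

# CPTs from the cells above, re-entered as nested dicts.
P_D = {'hard': 0.6, 'easy': 0.4}
P_I = {'intelligent': 0.7, 'not-intelligent': 0.3}
P_G = {  # P(grade | difficulty, intelligence)
    ('hard', 'intelligent'): {'A': 0.3, 'B': 0.4, 'C': 0.3},
    ('hard', 'not-intelligent'): {'A': 0.05, 'B': 0.25, 'C': 0.7},
    ('easy', 'intelligent'): {'A': 0.9, 'B': 0.08, 'C': 0.02},
    ('easy', 'not-intelligent'): {'A': 0.5, 'B': 0.3, 'C': 0.2},
}
P_S = {  # P(sat | intelligence)
    'intelligent': {'high-mark': 0.95, 'low-mark': 0.05},
    'not-intelligent': {'high-mark': 0.2, 'low-mark': 0.8},
}
P_L = {  # P(letter | grade)
    'A': {'reference-letter': 0.9, 'no-reference-letter': 0.1},
    'B': {'reference-letter': 0.6, 'no-reference-letter': 0.4},
    'C': {'reference-letter': 0.01, 'no-reference-letter': 0.99},
}

def joint(d, i, g, s, l):
    """P(d, i, g, s, l): one factor per node, conditioned on its parents
    (edges d -> g, i -> g, i -> s, g -> l)."""
    return P_D[d] * P_I[i] * P_G[(d, i)][g] * P_S[i][s] * P_L[g][l]

p = joint('hard', 'intelligent', 'A', 'high-mark', 'reference-letter')
print(p)  # 0.6 * 0.7 * 0.3 * 0.95 * 0.9, about 0.1077

# Summing the joint over all 2 * 2 * 3 * 2 * 2 = 48 assignments should give 1.
total = sum(joint(*a) for a in product(P_D, P_I, 'ABC',
                                       ['high-mark', 'low-mark'],
                                       ['reference-letter', 'no-reference-letter']))
print(total)
```

Five small tables (2 + 2 + 12 + 4 + 6 = 26 numbers) stand in for a full joint table of 48 entries, which is the saving the graph structure buys.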
In [6]:
network.bake()
Observe how the beliefs about the unobserved variables change once some outcomes are observed. Here the course is known to be hard and the student got a high SAT mark:
In [7]:
observations = { 'difficulty': 'hard', 'sat':'high-mark' }
# predict_proba gives one belief per state; they are paired with network.states by position
beliefs = map( str, network.predict_proba( observations ) )
print( "\n".join( "{}\t{}".format( state.name, belief ) for state, belief in zip( network.states, beliefs ) ) )
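`predict_proba` returns, for each unobserved variable, its distribution given the observations. For a network this small, those marginals can be cross-checked by brute-force enumeration: sum the joint probability over every assignment consistent with the evidence, then normalize. This is not how pomegranate computes its beliefs (it uses message passing). However, because this network's skeleton is a tree, the two should agree up to numerical tolerance. The following is a minimal sketch in plain Python with the same CPT numbers. `DOMAINS`, `joint` and `posterior` are hypothetical helpers:

```python
from itertools import product

# Same CPTs as in the cells above, re-entered as plain dicts.
DOMAINS = {
    'difficulty': ['hard', 'easy'],
    'intelligence': ['intelligent', 'not-intelligent'],
    'grade': ['A', 'B', 'C'],
    'sat': ['high-mark', 'low-mark'],
    'letter': ['reference-letter', 'no-reference-letter'],
}
P_D = {'hard': 0.6, 'easy': 0.4}
P_I = {'intelligent': 0.7, 'not-intelligent': 0.3}
P_G = {('hard', 'intelligent'): {'A': 0.3, 'B': 0.4, 'C': 0.3},
       ('hard', 'not-intelligent'): {'A': 0.05, 'B': 0.25, 'C': 0.7},
       ('easy', 'intelligent'): {'A': 0.9, 'B': 0.08, 'C': 0.02},
       ('easy', 'not-intelligent'): {'A': 0.5, 'B': 0.3, 'C': 0.2}}
P_S = {'intelligent': {'high-mark': 0.95, 'low-mark': 0.05},
       'not-intelligent': {'high-mark': 0.2, 'low-mark': 0.8}}
P_L = {'A': {'reference-letter': 0.9, 'no-reference-letter': 0.1},
       'B': {'reference-letter': 0.6, 'no-reference-letter': 0.4},
       'C': {'reference-letter': 0.01, 'no-reference-letter': 0.99}}

def joint(a):
    """Chain-rule product for a full assignment dict a."""
    return (P_D[a['difficulty']] * P_I[a['intelligence']]
            * P_G[(a['difficulty'], a['intelligence'])][a['grade']]
            * P_S[a['intelligence']][a['sat']]
            * P_L[a['grade']][a['letter']])

def posterior(evidence):
    """Hypothetical helper: exact marginals P(variable | evidence) by enumeration.
    Observed variables come back as point masses (1.0 on the observed value)."""
    names = list(DOMAINS)
    beliefs = {n: {v: 0.0 for v in DOMAINS[n]} for n in names}
    z = 0.0  # P(evidence), the normalizing constant
    for values in product(*(DOMAINS[n] for n in names)):
        a = dict(zip(names, values))
        if any(a[k] != v for k, v in evidence.items()):
            continue  # skip assignments that contradict the evidence
        p = joint(a)
        z += p
        for n in names:
            beliefs[n][a[n]] += p
    return {n: {v: w / z for v, w in dist.items()} for n, dist in beliefs.items()}

observations = {'difficulty': 'hard', 'sat': 'high-mark'}
for name, dist in posterior(observations).items():
    print(name, dist)
```

With these numbers, the high SAT mark raises P(intelligent) from the prior 0.7 to 0.665 / 0.725, roughly 0.92. Adding `'letter': 'no-reference-letter'` to the observations lowers it to roughly 0.87. A missing letter makes a low grade more likely, and that evidence flows back through `grade` to `intelligence`.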