Convex Optimized

Building a probabilistic programming interpreter

by Rob Zinkov on 2015-08-25

Interpreters for probabilistic programming languages (PPLs) can often seem a little mysterious. In fact, if you know how to write an interpreter for a simple language, writing one for a PPL isn’t much more work.

Using Haskell as the host language, I’ll show how to write a simple PPL that uses importance sampling as its underlying inference method. Apart from pattern matching there is nothing Haskell-specific here, so this example should be easy to port to other languages.

To start, let’s import some things and set up some basic types.

import Data.List hiding (empty, insert, map)
import Control.Monad

import Data.HashMap.Strict hiding (map)
import System.Random.MWC as MWC
import System.Random.MWC.Distributions as MD

type Name = String
type Env  = HashMap String Val

Our language has four kinds of values: doubles, booleans, functions, and pairs of values.

data Val =
    D Double |
    B Bool   |
    F (Val -> Val) |
    P Val Val

instance Eq Val where
  D x == D y         = x == y
  B x == B y         = x == y
  P x1 x2 == P y1 y2 = x1 == y1 && x2 == y2
  _ == _             = False

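instance Ord Val where
  D x <= D y         = x <= y
  B x <= B y         = x <= y
  P x1 x2 <= P y1 y2 = x1 <= y1 && x2 <= y2
  _ <= _             = error "Cannot compare functions or values of different types"

-- Functions have no meaningful textual form, so Show is written by hand;
-- print needs it when we display draws.
instance Show Val where
  show (D x)   = show x
  show (B b)   = show b
  show (P a b) = "(" ++ show a ++ "," ++ show b ++ ")"
  show (F _)   = "<function>"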

Expressions in this language cover literals, variables, pairs, conditionals, comparisons, functions and application, and arithmetic.

data Expr =
     Lit Double |
     Var Name |
     Pair Expr Expr |
     Fst Expr |
     Snd Expr |
     If  Expr Expr Expr |

     Eql Expr Expr |
     Les Expr Expr |
     Gre Expr Expr |
     And Expr Expr |

     Lam Name Expr |
     App Expr Expr |

     Add Expr Expr |
     Sub Expr Expr |
     Mul Expr Expr |
     Div Expr Expr
 deriving (Eq, Show)

We can evaluate expressions in this language without doing anything special; evalT is an ordinary environment-passing interpreter.

evalT :: Expr -> Env -> Val
evalT (Lit a) _            = D a
evalT (Var x)      env     = env ! x
evalT (Lam x body) env     = F (\ x' -> evalT body (insert x x' env))
evalT (App f x)    env     = app (evalT f env) (evalT x env)
           
evalT (Eql x y)    env     = B $ (evalT x env) == (evalT y env)
evalT (Les x y)    env     = B $ (evalT x env) <= (evalT y env)
evalT (Gre x y)    env     = B $ (evalT x env) >= (evalT y env)
evalT (And x y)    env     = liftB (&&) (evalT x env) (evalT y env)
                
evalT (Add x y)    env     = liftOp (+) (evalT x env) (evalT y env)
evalT (Sub x y)    env     = liftOp (-) (evalT x env) (evalT y env)
evalT (Mul x y)    env     = liftOp (*) (evalT x env) (evalT y env)
evalT (Div x y)    env     = liftOp (/) (evalT x env) (evalT y env)
                           
evalT (Pair x y)   env     = P (evalT x env) (evalT y env)
evalT (Fst x)      env     = fst_ $ evalT x env
 where fst_ (P a b)        = a
evalT (Snd x)      env     = snd_ $ evalT x env
 where snd_ (P a b)        = b
evalT (If b t f)   env     = if_ (evalT b env) (evalT t env) (evalT f env)
 where if_ (B True)  t' f' = t'
       if_ (B False) t' f' = f'

app :: Val -> Val -> Val
app (F f') x'   = f' x'

liftOp :: (Double -> Double -> Double) ->
          Val     -> Val    -> Val
liftOp op (D e1) (D e2) = D (op e1 e2)

liftB  :: (Bool -> Bool -> Bool) ->
          Val     -> Val    -> Val
liftB  op (B e1) (B e2) = B (op e1 e2)
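The interesting case in evalT is Lam: a function in our language becomes a Haskell closure that extends the environment when it is applied. The sketch below is a trimmed-down, standalone copy of the evaluator (only Lit, Var, Lam, App and Mul) that assumes an association list instead of a HashMap, so it needs nothing beyond base. The names square3 and asDouble are hypothetical helpers for this illustration.

```haskell
-- A standalone, trimmed-down copy of evalT (only Lit, Var, Lam, App, Mul).
-- Assumption: the environment is an association list instead of a HashMap,
-- so the sketch depends only on base.
data Val = D Double | F (Val -> Val)

data Expr
  = Lit Double
  | Var String
  | Lam String Expr
  | App Expr Expr
  | Mul Expr Expr

type Env = [(String, Val)]

evalT :: Expr -> Env -> Val
evalT (Lit a) _ = D a
evalT (Var x) env = case lookup x env of
  Just v  -> v
  Nothing -> error ("unbound variable: " ++ x)
-- A lambda becomes a Haskell closure that extends the environment
-- it was defined in with the argument it is applied to.
evalT (Lam x body) env = F (\v -> evalT body ((x, v) : env))
evalT (App f a) env = case evalT f env of
  F f' -> f' (evalT a env)
  _    -> error "applying a non-function"
evalT (Mul x y) env = case (evalT x env, evalT y env) of
  (D a, D b) -> D (a * b)
  _          -> error "Mul expects two doubles"

-- Hypothetical example program: (\x -> x * x) 3
square3 :: Expr
square3 = App (Lam "x" (Mul (Var "x") (Var "x"))) (Lit 3)

-- Hypothetical helper to read a double back out of a Val.
asDouble :: Val -> Double
asDouble (D d) = d
asDouble _     = error "not a double"
```

Loaded into GHCi, `asDouble (evalT square3 [])` evaluates to `9.0`.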

Of course, this isn’t yet a probabilistic programming language, so now we extend the language to include measures.

data Meas =
     Uniform Expr Expr |
     Weight  Expr Expr |
     Bind Name Meas Meas
 deriving (Eq, Show)

Let’s take a moment to explain what makes something a measure. Measures can be considered unnormalized probability distributions: if you add up the probabilities of all the disjoint outcomes of an unnormalized distribution, the total need not be 1.

This matters because we will represent a measure by weighted draws from the underlying distribution, and those draws need to be normalized before they can be read as a probability distribution.
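As a minimal sketch of that normalization step, assuming draws are represented as (value, weight) pairs, a hypothetical normalize helper divides every weight by the total weight:

```haskell
-- Hypothetical helper: turn weighted draws (value, unnormalized weight)
-- into a discrete distribution whose weights sum to 1.
normalize :: [(a, Double)] -> [(a, Double)]
normalize draws
  | total <= 0 = error "normalize: total weight must be positive"
  | otherwise  = [ (x, w / total) | (x, w) <- draws ]
  where
    total = sum (map snd draws)
```

For example, `normalize [("a", 2), ("b", 6)]` gives `[("a",0.25),("b",0.75)]`.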

We can construct measures in one of three ways. The first is the continuous uniform distribution, whose bounds are given as expressions. The second is a weighted distribution, which always returns the value of its second argument, with the first argument as its weight. This is a probability distribution only when the first argument evaluates to one; we’ll call that case dirac:

dirac :: Expr -> Meas
dirac x = Weight (Lit 1.0) x

The final form is what lets us build up measure expressions. Bind takes a measure as input, together with a function from draws of that measure to another measure.

Because I don’t feel like defining measurable functions as their own form, Bind also takes a name for the variable that will hold the draws, so its last argument can refer to those draws through that variable. For example, here is how to take a draw from a uniform distribution and then double it:

prog1 = Bind "x" (Uniform (Lit 1) (Lit 5))   -- x <~ uniform(1, 5)
        (dirac (Add (Var "x") (Var "x")))   -- return (x + x)

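Measures are evaluated by producing a weighted sample from the measure they represent: evalM returns a draw together with its weight, and Bind multiplies the weights of its two parts. Drawing samples from one distribution and weighting them to account for the distribution we actually care about is called importance sampling.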

evalM :: Meas -> Env -> MWC.GenIO -> IO (Val, Double)
evalM (Uniform lo hi) env g = do
                              let D lo' = evalT lo env
                              let D hi' = evalT hi env
                              x <- MWC.uniformR (lo', hi') g
                              return (D x, 1.0)
evalM (Weight i x)    env g = do
                              let D i' = evalT i env
                              return (evalT x env, i')
evalM (Bind x m f)    env g = do
                              (x', w)  <- evalM m env g
                              let env' = insert x x' env
                              (f', w1) <- evalM f env' g
                              return (f', w*w1)
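The same weighted-sampling semantics can be written without IO, which makes it easier to see why Bind multiplies weights. In this sketch, Sampler, uniformS, weightS, bindS and prog1S are hypothetical names, not part of the interpreter above; a sampler consumes numbers from a supply of U(0,1) draws instead of an MWC generator and returns the value, its weight, and the unused part of the supply.

```haskell
-- A pure sketch of evalM's semantics. Assumption: instead of an MWC
-- generator, a sampler consumes numbers from a supply of U(0,1) draws and
-- returns (value, importance weight, unused supply).
newtype Sampler a = Sampler { runSampler :: [Double] -> (a, Double, [Double]) }

-- Like evalM's Uniform case: rescale one U(0,1) draw to [lo, hi], weight 1.
uniformS :: Double -> Double -> Sampler Double
uniformS lo hi = Sampler $ \us -> case us of
  (u : rest) -> (lo + (hi - lo) * u, 1.0, rest)
  []         -> error "uniformS: ran out of random numbers"

-- Like the Weight case: return x with weight w, consuming no randomness.
weightS :: Double -> a -> Sampler a
weightS w x = Sampler $ \us -> (x, w, us)

-- Like the Bind case: run m, pass its value to f, multiply the weights.
bindS :: Sampler a -> (a -> Sampler b) -> Sampler b
bindS m f = Sampler $ \us ->
  let (x, w1, us')  = runSampler m us
      (y, w2, us'') = runSampler (f x) us'
  in (y, w1 * w2, us'')

-- prog1 from the text: x <~ uniform(1, 5); return (x + x)
prog1S :: Sampler Double
prog1S = uniformS 1 5 `bindS` \x -> weightS 1.0 (x + x)
```

With the supply `[0.5]`, `runSampler prog1S [0.5]` yields `(6.0,1.0,[])`, since 0.5 rescales to x = 3. Feeding an infinite stream of pseudo-random numbers instead would mirror evalM. Sampler is essentially a combined state-and-writer monad, which is why Bind looks so much like monadic bind.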

We can run this program as follows:

test1 :: IO ()
test1 = do
   g <- MWC.create
   draw <- evalM prog1 empty g
   print draw

(7.926912543562406,1.0)

Evaluating this program repeatedly produces as many draws from this measure as you need. This is great, in that we can represent any unconditioned distribution built from these forms. But how do we represent conditional distributions?

For that we will introduce another datatype:

data Cond =
    UCond Meas |
    UniformC Expr Expr Expr |
    WeightC  Expr Expr Expr |
    BindC Name Cond Cond

Cond is an extension of Meas: a measure is either unconditioned (UCond), or it is conditioned, in which case each constructor takes one extra argument giving the value it is conditioned on. To draw from a conditioned measure, we convert it into an unconditioned one.
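The weight that conditioning assigns is the density of the observed value under the measure. For the continuous uniform distribution on [lo, hi] that density is 1/(hi - lo) inside the interval and 0 outside. The hypothetical helpers below make this explicit, including the fact that several independent observations multiply their densities, just as nested Binds multiply weights.

```haskell
-- Hypothetical helper: density of the continuous uniform distribution on
-- [lo, hi] at y, i.e. 1 / (hi - lo) inside the interval and 0 outside.
uniformDensity :: Double -> Double -> Double -> Double
uniformDensity lo hi y
  | lo <= y && y <= hi = 1 / (hi - lo)
  | otherwise          = 0

-- Hypothetical helper: observing several values independently drawn from
-- uniform(lo, hi) multiplies their densities.
likelihood :: Double -> Double -> [Double] -> Double
likelihood lo hi ys = product (map (uniformDensity lo hi) ys)
```

For the draw x = 2 in prog2 below, for instance, observing y = 3 from uniform(2, 7) contributes `uniformDensity 2 7 3`, which is 0.2.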

evalC :: Cond -> Meas
evalC (UCond    m      ) = m
-- A conditioned uniform contributes its density, 1 / (hi - lo), when the
-- observed value x lies in [lo, hi], and weight 0 otherwise.
evalC (UniformC lo hi x) = Weight (If (And (Gre x lo)
                                           (Les x hi))
                                      (Div (Lit 1) (Sub hi lo))
                                      (Lit 0)) x
-- A conditioned Weight keeps weight i only if the value x it would
-- return equals the observed value y.
evalC (WeightC  i x   y) = Weight (If (Eql x y)
                                      i
                                      (Lit 0)) y
evalC (BindC    x m f)   = Bind x (evalC m) (evalC f)

What evalC does is determine what weight to assign to a measure given that we know it produced a particular value. That weight is the probability of the measure producing that value or, for a continuous measure such as the uniform, its probability density at that value.

And that’s all you need to express probabilistic programs. Take the following example. Suppose we have two random variables x and y, where the value of y depends on x:

x <~ uniform(1, 5)
y <~ uniform(x, 7)

What is the conditional distribution of x given that y is 3?

prog2 = BindC "x" (UCond (Uniform (Lit 1) (Lit 5)))      -- x <~ uniform(1, 5)
         (BindC "_" (UniformC (Var "x") (Lit 7) (Lit 3)) -- y <~ uniform(x, 7)
                                                         -- observe y 3
          (UCond (dirac (Var "x"))))                     -- return x

test2 :: IO ()
test2 = do
   g <- MWC.create
   samples <- replicateM 10 (evalM (evalC prog2) empty g)
   print samples

Running test2 prints ten weighted samples (weights rounded here to six decimal places):

[(1.099241451531848, 0.169470),
 (3.963456271781203, 0.0),
 (1.637454187135532, 0.186479),
 (3.781075065891581, 0.0),
 (1.908186342514358, 0.196394),
 (2.799366130116895, 0.238059),
 (3.091757816253942, 0.0),
 (1.486166046469419, 0.181362),
 (3.106369061983323, 0.0),
 (1.225163855492708, 0.173165)]

As you can see, any draw with x above 3 gets weight 0, because y drawn from uniform(x, 7) can never equal 3 when x > 3.
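The weighted draws are turned into answers by self-normalized importance sampling: an expectation of f(x) is estimated as the sum of w_i * f(x_i) divided by the sum of w_i. Here the exact answer is available for comparison. The posterior density of x is proportional to 1/(7 - x) on [1, 3], so E[x | y = 3] = 7 - 2/ln(3/2), approximately 2.067. The standalone sketch below assumes the density 1/(hi - lo) for the uniform likelihood and, to stay deterministic and dependency-free, uses an evenly spaced grid of proposal points in place of pseudo-random draws; the function names are hypothetical.

```haskell
-- Self-normalized importance sampling for prog2, as a standalone sketch.
-- Proposal: x ~ uniform(1, 5) (the prior). Weight: the density of
-- observing y = 3 under uniform(x, 7). Assumption: to stay deterministic
-- and dependency-free, the U(0,1) inputs are grid midpoints rather than
-- pseudo-random numbers.
uniformDensity :: Double -> Double -> Double -> Double
uniformDensity lo hi y
  | lo <= y && y <= hi = 1 / (hi - lo)
  | otherwise          = 0

-- n weighted draws (x, w) for the model in prog2.
weightedDraws :: Int -> [(Double, Double)]
weightedDraws n =
  [ (x, uniformDensity x 7 3)
  | i <- [1 .. n]
  , let u = (fromIntegral i - 0.5) / fromIntegral n
  , let x = 1 + 4 * u
  ]

-- Estimate E[x | y = 3] as sum (w * x) / sum w.
posteriorMean :: [(Double, Double)] -> Double
posteriorMean draws = sum [ w * x | (x, w) <- draws ] / sum (map snd draws)
```

`posteriorMean (weightedDraws 10000)` should land within 0.001 of 7 - 2/ln(3/2). With real random draws, as in test2, the same estimator applies but its error shrinks only as the number of samples grows.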

Further reading

For small problems this implementation is actually fairly capable, and it can be extended to support more probability distributions in a straightforward way.
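As one example of such an extension, supporting a normal distribution needs two pieces: a way to draw from it, and its density for the conditioned case. The sketch below assumes that hypothetical Normal and NormalC constructors would call these helpers. It draws via the Box-Muller transform from two U(0,1) numbers; mwc-random’s System.Random.MWC.Distributions, imported above as MD, also provides a normal sampler (`MD.normal`) that evalM could call instead.

```haskell
-- Hypothetical helpers for adding a normal distribution.

-- Box-Muller transform: two independent U(0,1) draws (u1 must be > 0)
-- become one draw from Normal(mu, sigma).
normalFromUniforms :: Double -> Double -> Double -> Double -> Double
normalFromUniforms mu sigma u1 u2 =
  mu + sigma * sqrt (-2 * log u1) * cos (2 * pi * u2)

-- Density of Normal(mu, sigma) at y: the weight a conditioned normal
-- would contribute, analogous to the UniformC case of evalC.
normalDensity :: Double -> Double -> Double -> Double
normalDensity mu sigma y = exp (-0.5 * z * z) / (sigma * sqrt (2 * pi))
  where
    z = (y - mu) / sigma
```

In evalC, a hypothetical `NormalC mu sigma y` case would then mirror UniformC, returning y with weight `normalDensity mu sigma y`.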

If you are interested in more advanced interpreters I suggest reading the following.