Probabilistic Programming with Pyro in WML


In a previous post we explained how to write a probabilistic model using Edward and run it on the IBM Watson Machine Learning (WML) platform. In this post, we discuss the same example written in Pyro, a deep probabilistic programming language built on top of PyTorch.

Deep probabilistic programming languages (DPPLs) such as Edward and Pyro aim to combine the advantages of probabilistic programming languages (i.e., an intuitive formalism and dedicated constructs for building probabilistic models) with those of deep learning frameworks (i.e., the ability to write, train, and deploy deep learning models) in order to build advanced probabilistic models.

This post illustrates, with a simple example, how to use Pyro and the IBM Watson Machine Learning (WML) platform to write and train a simple but complete probabilistic model involving a deep learning network. WML is a cloud service that allows developers to efficiently train, deploy, and monitor machine learning models on fast GPUs. You can download the complete code here.

Quick start with WML

Pyro is now available in WML with PyTorch 0.4, so you can write Pyro code and run it on GPUs with WML. Before we start, you need a ready-to-use WML environment, that is, access to WML and a Cloud Object Storage service.

Your configuration file manifest.yml should look something like this (filled with your object storage credentials):

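model_definition:
  author:
    name: Me
  description: Simple MLP in Pyro for classifying MNIST
  execution:
    command: python
    compute_configuration:
      name: k80
  framework:
    name: pytorch
    version: '0.4'
  name: pyro_mnist_mlp
training_data_reference:
  connection:
    access_key_id: xxxxxxxxxxxxxxx
    endpoint_url: xxxxxxxxxxxxxxx
    secret_access_key: xxxxxxxxxxxxxxx
  name: training_data_reference_name
  source:
    bucket: xxxxxxxxxxxxxxx
  type: s3
training_results_reference:
  connection:
    access_key_id: xxxxxxxxxxxxxxx
    endpoint_url: xxxxxxxxxxxxxxx
    secret_access_key: xxxxxxxxxxxxxxx
  name: training_results_reference_name
  target:
    bucket: xxxxxxxxxxxxxxx
  type: s3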

Bayesian MLP in Pyro

Our previous post explained how to write, in Edward, a simple probabilistic multi-layer perceptron (MLP) for classifying handwritten digits. The main idea is to treat all the weights and biases of the network as random variables, so that we learn a complete distribution for each parameter of the network instead of a single value. These distributions can be used to measure the uncertainty associated with the output of the network, which can be critical for decision-making systems. As in Edward, we use variational inference to learn these distributions in Pyro.
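As a minimal, stdlib-only sketch of how such uncertainty can be read off, assume we already have the class-probability vectors produced by a few networks sampled from the learned distributions. The helper `predictive_summary` and the toy numbers below are hypothetical: the sketch averages the probabilities, reports their per-class standard deviation across samples, and computes the entropy of the averaged prediction, which grows as the prediction becomes less certain.

```python
import math
from statistics import mean, pstdev
from typing import List, Tuple


def predictive_summary(
    sampled_probs: List[List[float]],
) -> Tuple[List[float], List[float], float]:
    """Summarize predictions from several sampled networks (hypothetical helper).

    sampled_probs[s][k] is the probability that sampled network s assigns to class k.
    Returns (mean probability per class, std per class, entropy of the mean).
    """
    per_class = list(zip(*sampled_probs))  # regroup the values class by class
    avg = [mean(col) for col in per_class]
    spread = [pstdev(col) for col in per_class]
    # entropy (in nats) of the averaged distribution
    entropy = -sum(p * math.log(p) for p in avg if p > 0)
    return avg, spread, entropy


# Made-up outputs of 3 sampled networks over 3 classes.
confident = [[0.9, 0.05, 0.05], [0.85, 0.1, 0.05], [0.95, 0.03, 0.02]]
uncertain = [[0.6, 0.3, 0.1], [0.2, 0.7, 0.1], [0.4, 0.2, 0.4]]

for name, probs in [("confident", confident), ("uncertain", uncertain)]:
    avg, spread, h = predictive_summary(probs)
    print(name, [round(p, 3) for p in avg], round(h, 3))
```

In the uncertain case the sampled networks disagree on the most likely class, which shows up both as a larger spread and as a higher entropy.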

In short, we need to define two main components:

  1. The probabilistic model: an MLP in which all weights and biases are treated as random variables; and
  2. A family of guide distributions for variational inference.
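For reference, the objective that variational inference maximizes can be written as follows, using standard notation chosen here rather than taken from the post: w denotes the network weights and biases, x the images, y the labels, p(w) the prior, and q_φ(w) the guide with parameters φ. Since log p(y | x) does not depend on φ, maximizing this evidence lower bound (ELBO) is equivalent to minimizing the KL divergence between the guide and the true posterior; it is the quantity estimated by the Trace_ELBO loss used in the training loop.

```latex
\begin{aligned}
\mathrm{ELBO}(\phi)
  &= \mathbb{E}_{q_\phi(w)}\big[\log p(y \mid x, w) + \log p(w) - \log q_\phi(w)\big] \\
  &= \log p(y \mid x) - \mathrm{KL}\big(q_\phi(w) \,\|\, p(w \mid x, y)\big)
\end{aligned}
```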

The corresponding Pyro code is the following:

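import torch
import torch.nn as nn
from torch.nn.functional import log_softmax, softplus

import pyro
from pyro.distributions import Normal, Categorical

# MNIST images have 28x28 pixels and there are 10 digit classes;
# the hidden-layer size nh is an arbitrary choice (tune as needed)
nx, nh, ny = 28 * 28, 1024, 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.l1 = nn.Linear(nx, nh)  # input -> hidden layer
        self.l2 = nn.Linear(nh, ny)  # hidden layer -> class scores
        self.relu = nn.ReLU()

    def forward(self, x):
        h = self.relu(self.l1(x.view((-1, nx))))  # flatten each image
        yhat = self.l2(h)
        return yhat

mlp = MLP().to(device)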


def normal(*shape):
    loc = torch.zeros(*shape).to(device)
    scale = torch.ones(*shape).to(device)
    return Normal(loc, scale)

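def model(imgs, lbls):
    # standard normal priors over every weight and bias of the MLP
    priors = {
        'l1.weight': normal(nh, nx), 'l1.bias': normal(nh),
        'l2.weight': normal(ny, nh), 'l2.bias': normal(ny)}
    # lift the deterministic MLP into a distribution over MLPs
    lifted_module = pyro.random_module("mlp", mlp, priors)
    lifted_reg_model = lifted_module()  # sample one concrete network
    lhat = log_softmax(lifted_reg_model(imgs), dim=1)
    # condition the model on the observed labels
    pyro.sample("obs", Categorical(logits=lhat), obs=lbls)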

# Inference guide

def vnormal(name, *shape):
    # learnable mean and unconstrained scale for one weight or bias tensor
    loc = pyro.param(name + "m",
                     torch.randn(*shape, requires_grad=True, device=device))
    scale = pyro.param(name + "s",
                       torch.randn(*shape, requires_grad=True, device=device))
    # softplus keeps the standard deviation positive
    return Normal(loc, softplus(scale))

def guide(imgs, lbls):
    dists = {
        'l1.weight': vnormal("W1", nh, nx), 'l1.bias': vnormal("b1", nh),
        'l2.weight': vnormal("W2", ny, nh), 'l2.bias': vnormal("b2", ny)}
    lifted_module = pyro.random_module("mlp", mlp, dists)
    return lifted_module()  # a network sampled from the variational distribution

The MLP network is defined in PyTorch. In the model, we first define the prior distributions for all the weights and biases and then lift the MLP definition from concrete to probabilistic using the pyro.random_module function. The network output, normalized by log_softmax into lhat, parameterizes a categorical distribution over the possible labels of each image in imgs. Note the pyro.sample statement with the obs argument: during inference, it conditions the model on the known labels lbls, matching the network's predictions against them.
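The following stdlib-only sketch mimics the idea behind lifting; it is not Pyro's implementation. A hypothetical `lift` helper draws every parameter of a deterministic model from its prior and returns one concrete model, so each call yields a different network. The one-neuron `linear` model and its standard normal priors are illustrative assumptions.

```python
import random
from typing import Callable, Dict

Params = Dict[str, float]


def lift(forward: Callable[[Params, float], float],
         priors: Dict[str, Callable[[random.Random], float]],
         rng: random.Random) -> Callable[[float], float]:
    """Sample every parameter from its prior and return a concrete model.

    Hypothetical helper illustrating the idea behind pyro.random_module.
    """
    params = {name: sample(rng) for name, sample in priors.items()}
    return lambda x: forward(params, x)


def linear(params: Params, x: float) -> float:
    # a one-neuron "network": y = w * x + b
    return params["w"] * x + params["b"]


priors = {
    "w": lambda rng: rng.gauss(0.0, 1.0),
    "b": lambda rng: rng.gauss(0.0, 1.0),
}

rng = random.Random(0)
nets = [lift(linear, priors, rng) for _ in range(3)]
outputs = [net(2.0) for net in nets]
print(outputs)
```

Each element of nets is an ordinary deterministic function with its own sampled w and b, which is the role lifted_module() plays in the model above.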

The guide defines the family of distributions used for variational inference. In our case, every parameter follows a normal distribution. Note the use of pyro.param to declare the learnable parameters of the family (here, the mean and the unconstrained scale of each normal distribution; softplus maps the latter to a positive standard deviation). After training, these parameters describe the member of the family that (approximately) minimizes the KL divergence to the true posterior.
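Two details of the guide can be checked numerically with a small stdlib-only sketch (the helper names are our own). First, softplus(s) = log(1 + e^s) maps any real number to a positive one, which is why the unconstrained scale parameter goes through it. Second, for a single weight with a normal guide and a normal prior, the KL divergence between them has a closed form. This KL-to-the-prior term is only one part of the objective: the ELBO also contains the expected log-likelihood of the labels.

```python
import math


def softplus(s: float) -> float:
    """log(1 + exp(s)), always > 0.

    Uses log(1 + e^s) = max(s, 0) + log(1 + e^-|s|) to avoid overflow.
    """
    return max(s, 0.0) + math.log1p(math.exp(-abs(s)))


def kl_normal(mu_q: float, sd_q: float, mu_p: float, sd_p: float) -> float:
    """KL(N(mu_q, sd_q^2) || N(mu_p, sd_p^2)) for univariate normals."""
    return (math.log(sd_p / sd_q)
            + (sd_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sd_p ** 2)
            - 0.5)


# A guide entry with mean 0.3 and unconstrained scale -1.0,
# compared with the standard normal prior N(0, 1) used in the model.
sd = softplus(-1.0)
print(sd, kl_normal(0.3, sd, 0.0, 1.0))
```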


Before starting the training, let us load the MNIST dataset.

from torch.utils.data import dataloader
from torchvision import transforms
from torchvision.datasets import MNIST

batch_size = 128  # mini-batch size (an arbitrary choice; tune as needed)

train = MNIST("MNIST", train=True, download=True,
              transform=transforms.Compose([transforms.ToTensor()]))
test = MNIST("MNIST", train=False, download=True,
             transform=transforms.Compose([transforms.ToTensor()]))
dataloader_args = dict(shuffle=True, batch_size=batch_size,
                       num_workers=1, pin_memory=False)
train_loader = dataloader.DataLoader(train, **dataloader_args)
test_loader = dataloader.DataLoader(test, **dataloader_args)
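With shuffle=True, a DataLoader reshuffles the dataset at every epoch and yields batches of batch_size examples; since drop_last defaults to False, the last batch may be smaller. The hypothetical generator below is a stdlib-only sketch of that batching behavior for a plain list, not PyTorch's implementation (which also handles worker processes, collation into tensors, and memory pinning).

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def minibatches(data: Sequence[T], batch_size: int, shuffle: bool = True,
                seed: int = 0) -> Iterator[List[T]]:
    """Yield consecutive batches of data, optionally in a shuffled order."""
    order = list(range(len(data)))
    if shuffle:
        random.Random(seed).shuffle(order)  # one permutation per seed
    for start in range(0, len(order), batch_size):
        # the last batch is shorter when len(data) % batch_size != 0
        yield [data[i] for i in order[start:start + batch_size]]


batches = list(minibatches(list(range(10)), batch_size=4, seed=1))
print([len(b) for b in batches])  # → [4, 4, 2]
```

Passing a different seed at each epoch (for example seed=epoch) mimics the per-epoch reshuffling.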


We can now launch the inference.

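from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

num_epochs = 10  # number of passes over the training set (an arbitrary choice)
inference = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())

for epoch in range(num_epochs):
    for j, (imgs, lbls) in enumerate(train_loader, 0):
        # one stochastic gradient step on the ELBO for this mini-batch
        loss = inference.step(imgs.to(device), lbls.to(device))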
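To see what each inference.step call does, here is a stdlib-only sketch of stochastic variational inference on a toy model whose exact posterior is known: a single parameter theta with a N(0, 1) prior and observations x_i ~ N(theta, 1). The guide is N(m, softplus(rho)^2), mirroring the guide above, and each step performs gradient ascent on a Monte Carlo estimate of the ELBO using the reparameterization theta = m + sd * eps. The gradients are derived by hand for this toy model, so this only illustrates the technique; Pyro instead relies on PyTorch autodiff and, here, the Adam optimizer. The function name, learning rate, and number of samples are arbitrary choices.

```python
import math
import random
from typing import List, Tuple


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def softplus(z: float) -> float:
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def toy_svi(xs: List[float], steps: int = 2000, lr: float = 0.01,
            num_particles: int = 10, seed: int = 0) -> Tuple[float, float]:
    """Fit q(theta) = N(m, softplus(rho)^2) to the posterior of the toy model
    theta ~ N(0, 1), x_i ~ N(theta, 1) by stochastic gradient ascent on the ELBO."""
    rng = random.Random(seed)
    n, total = len(xs), sum(xs)
    m, rho = 0.0, 0.0  # variational parameters (cf. pyro.param in the guide)
    for _ in range(steps):
        sd = softplus(rho)
        grad_m = grad_sd = 0.0
        for _ in range(num_particles):
            eps = rng.gauss(0.0, 1.0)
            theta = m + sd * eps  # reparameterized sample from the guide
            # d/dtheta of log N(theta; 0, 1) + sum_i log N(x_i; theta, 1)
            dlogp = total - (n + 1) * theta
            grad_m += dlogp / num_particles
            grad_sd += dlogp * eps / num_particles
        grad_sd += 1.0 / sd  # derivative of the guide's entropy, log(sd) + const
        m += lr * grad_m
        rho += lr * grad_sd * sigmoid(rho)  # chain rule: d softplus / d rho = sigmoid
    return m, softplus(rho)


data_rng = random.Random(42)
xs = [data_rng.gauss(1.5, 1.0) for _ in range(20)]
m, sd = toy_svi(xs)
# exact posterior of this conjugate model: N(sum(xs) / (n + 1), 1 / (n + 1))
exact_mean = sum(xs) / (len(xs) + 1)
exact_sd = math.sqrt(1.0 / (len(xs) + 1))
print(m, exact_mean, sd, exact_sd)
```

Because the toy posterior is itself normal, the guide family contains it, and the fitted m and sd should land close to the exact values; for the MLP the normal guide is only an approximation.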


Once training is complete, we can sample networks from the guide, which now approximates the posterior distribution, multiple times to obtain a set of MLPs. We then average the outputs of all these MLPs to compute a prediction.

num_samples = 10  # number of networks sampled from the guide (an arbitrary choice)

def predict(x):
    # draw num_samples concrete MLPs from the learned approximate posterior
    sampled_models = [guide(None, None) for _ in range(num_samples)]
    yhats = [net(x).data for net in sampled_models]
    # average the outputs of the sampled networks and pick the best class
    mean = torch.mean(torch.stack(yhats), 0)
    return torch.argmax(mean, dim=1)

correct = 0
total = 0
for j, data in enumerate(test_loader):
    images, labels = data
    predicted = predict(images.to(device)).cpu()
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
print("accuracy: %d %%" % (100 * correct / total))

That’s it! You can now export the following environment variables with your WML credentials:

export ML_ENV=xxxxxxxxxxxxxxx
export ML_INSTANCE=xxxxxxxxxxxxxxx
export ML_USERNAME=xxxxxxxxxxxxxxx
export ML_PASSWORD=xxxxxxxxxxxxxxx

and run:

bx ml train <archive>.zip manifest.yml

where <archive>.zip is an archive containing all the Python source files (e.g., data loaders). This command returns an ID (e.g., training-xxxxxxxxx) that you can use to monitor the running job:

bx ml monitor training-runs training-xxxxxxxxx
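Since a missing or empty credential is easy to overlook, a small pre-flight check run before submitting the job can help. The helper below is a hypothetical stdlib-only sketch, not part of the WML tooling; it only checks that the four variables exported earlier are set and non-empty.

```python
import os
from typing import List, Mapping, Optional

# the environment variables exported earlier in this post
WML_VARS = ("ML_ENV", "ML_INSTANCE", "ML_USERNAME", "ML_PASSWORD")


def missing_wml_vars(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return the names of required WML variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in WML_VARS if not env.get(name)]


print(missing_wml_vars({"ML_ENV": "x", "ML_INSTANCE": ""}))
# → ['ML_INSTANCE', 'ML_USERNAME', 'ML_PASSWORD']
```

Called without arguments, missing_wml_vars() inspects the current process environment, so a wrapper script can abort when the returned list is non-empty.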

Further reading on Pyro and DPPLs

More examples in both Edward and Pyro can be found in this paper.

Research Staff Member, IBM Research

Benjamin Herta

IBM Research
