Probabilistic Programming with Pyro in WML


In a previous post we explained how to write a probabilistic model using Edward and run it on the IBM Watson Machine Learning (WML) platform. In this post, we discuss the same example written in Pyro, a deep probabilistic programming language built on top of PyTorch.

Deep probabilistic programming languages (DPPLs) such as Edward and Pyro aim to combine the advantages of probabilistic programming languages (i.e., intuitive formalism and dedicated constructs to build probabilistic models) and deep learning frameworks (i.e., the ability to write, train, and deploy DL models) to build advanced probabilistic models.

This post illustrates, with a simple example, how to use Pyro and the IBM Watson Machine Learning (WML) platform to write and train a simple but complete probabilistic model involving a deep learning network. WML is a cloud service that allows developers to efficiently train, deploy, and monitor machine learning models on fast GPUs. You can download the complete code here.

Quick start with WML

Pyro is now available in WML with PyTorch 0.4, so you can write Pyro code and run it on GPUs using WML. Before we start, you need a ready-to-use WML environment, that is, access to WML and a Cloud Object Storage service.

Your configuration file manifest.yml should look something like this (filled with your object storage credentials):

model_definition:
  author:
    name: Me
  description: Simple MLP in Pyro for classifying MNIST
  execution:
    command: python
    compute_configuration:
      name: k80
  framework:
    name: pytorch
    version: '0.4'
  name: pyro_mnist_mlp
training_data_reference:
  connection:
    access_key_id: xxxxxxxxxxxxxxx
    secret_access_key: xxxxxxxxxxxxxxx
  name: training_data_reference_name
  source:
    bucket: xxxxxxxxxxxxxxx
  type: s3
training_results_reference:
  connection:
    access_key_id: xxxxxxxxxxxxxxx
    secret_access_key: xxxxxxxxxxxxxxx
  name: training_results_reference_name
  target:
    bucket: xxxxxxxxxxxxxxx
  type: s3

Bayesian MLP in Pyro

Our previous post explained how to write a simple probabilistic multi-layer perceptron (MLP) for classifying handwritten digits in Edward. The main idea is to treat all the weights and biases of the network as random variables, so that we learn a complete distribution for each parameter of the network instead of a single value. These distributions can be used to measure the uncertainty associated with the output of the network, which can be critical for decision-making systems. As in Edward, we will use variational inference to learn the distributions in Pyro.

In short, we need to define two main components:

  1. The probabilistic model: an MLP where all weights and biases are treated as random variables; and
  2. A family of guide distributions for the variational inference.

The corresponding Pyro code is the following:

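import pyro
import torch
import torch.nn as nn
from pyro.distributions import Categorical, Normal
from torch.nn.functional import log_softmax, softplus

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nx, ny = 28 * 28, 10  # MNIST: 28x28 input pixels, 10 digit classes
nh = 1024             # hidden-layer size (assumed value)

# Deterministic PyTorch MLP with one hidden ReLU layer
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.l1 = nn.Linear(nx, nh)
        self.l2 = nn.Linear(nh, ny)
        self.relu = nn.ReLU()

    def forward(self, x):
        h = self.relu(self.l1(x.view((-1, nx))))  # flatten each image to nx pixels
        yhat = self.l2(h)
        return yhat

mlp = MLP().to(device)


def normal(*shape):
    """Standard normal N(0, 1) prior of the given shape."""
    loc = torch.zeros(*shape).to(device)
    scale = torch.ones(*shape).to(device)
    return Normal(loc, scale)

def model(imgs, lbls):
    # Priors over every weight and bias of the MLP
    priors = {
        'l1.weight': normal(nh, nx), 'l1.bias': normal(nh),
        'l2.weight': normal(ny, nh), 'l2.bias': normal(ny)}
    # Lift the MLP: its parameters become random variables drawn from the priors
    lifted_module = pyro.random_module("mlp", mlp, priors)
    lifted_reg_model = lifted_module()
    lhat = log_softmax(lifted_reg_model(imgs), dim=1)
    # Condition the model on the observed labels
    pyro.sample("obs", Categorical(logits=lhat), obs=lbls)

# Inference guide

def vnormal(name, *shape):
    """Normal with learnable loc and scale; softplus keeps the scale positive."""
    loc = pyro.param(name + "m",
                     torch.randn(*shape, requires_grad=True, device=device))
    scale = pyro.param(name + "s",
                       torch.randn(*shape, requires_grad=True, device=device))
    return Normal(loc, softplus(scale))

def guide(imgs, lbls):
    dists = {
        'l1.weight': vnormal("W1", nh, nx), 'l1.bias': vnormal("b1", nh),
        'l2.weight': vnormal("W2", ny, nh), 'l2.bias': vnormal("b2", ny)}
    lifted_module = pyro.random_module("mlp", mlp, dists)
    return lifted_module()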

The MLP network is defined in PyTorch. In the model, we first define the prior distributions for all the weights and biases, then lift the MLP definition from concrete to probabilistic using the pyro.random_module function. The output lhat (the log-probabilities of each digit class) parameterizes a categorical distribution over the possible labels of each image in imgs. Note the pyro.sample statement with the obs argument: during inference, it matches the prediction of the network with the known labels lbls.
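Conceptually, lifting replaces each named parameter of a fixed network with a draw from its prior. The observed sample statement then scores the true label under the resulting categorical distribution. The stdlib-only sketch below mimics both steps for a hypothetical linear network with 2 inputs and 3 classes. The names sample_params, forward and log_likelihood are illustrative, not Pyro functions. The real pyro.random_module also registers each draw as a named sample site, which lets inference pair it with the corresponding guide distribution.

```python
import math
import random

N_IN, N_OUT = 2, 3  # hypothetical sizes: 2 inputs, 3 classes


def sample_params(rng):
    """Draw each weight and bias from a N(0, 1) prior: one 'lifted' network."""
    return {
        "weight": [[rng.gauss(0.0, 1.0) for _ in range(N_IN)] for _ in range(N_OUT)],
        "bias": [rng.gauss(0.0, 1.0) for _ in range(N_OUT)],
    }


def forward(params, x):
    """Deterministic network: logits = weight @ x + bias."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(params["weight"], params["bias"])]


def log_softmax(z):
    """Numerically stable log-softmax of a list of logits."""
    top = max(z)
    lse = top + math.log(sum(math.exp(v - top) for v in z))
    return [v - lse for v in z]


def log_likelihood(params, x, label):
    """log p(label | x, params) under Categorical(logits=lhat), i.e. lhat[label]."""
    lhat = log_softmax(forward(params, x))
    return lhat[label]


rng = random.Random(0)
x, label = [0.5, -1.0], 2
for _ in range(3):
    # Each draw is a different network, hence a different likelihood.
    print(log_likelihood(sample_params(rng), x, label))
```

Because every call to sample_params yields a new network, the likelihood of the same labelled image changes from draw to draw. The guide is there to learn which draws are plausible.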

The guide defines the family of distributions used for variational inference. In our case, every parameter follows a normal distribution. Note the use of pyro.param to declare the learnable parameters of the family (here, the means and scales of the normal distributions, with softplus keeping each scale positive). After training, these parameters describe the member of the family that is closest, in the sense of KL divergence, to the true posterior.
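To make "closest to the true posterior" concrete, here is a toy, stdlib-only sketch that does not use Pyro or the MNIST model. It has a single weight w with prior N(0, 1), hypothetical observations y_i ~ N(w, 1), and a normal guide q(w) = N(loc, scale^2). SVI maximizes the evidence lower bound (ELBO), E_q[log p(y, w) - log q(w)]. The ELBO equals log p(y) minus the KL divergence from q to the posterior, so the guide that matches the posterior scores highest. The data values and function names are assumptions made for illustration.

```python
import math
import random

# Hypothetical observations, assumed drawn from N(w, 1)
DATA = [2.0, 2.2, 1.8, 2.0]


def log_normal(x, loc, scale):
    """Log-density of N(loc, scale**2) at x."""
    return -0.5 * math.log(2 * math.pi * scale ** 2) - (x - loc) ** 2 / (2 * scale ** 2)


def log_joint(w, data=DATA):
    """Model: log p(w) + sum_i log p(y_i | w), with prior w ~ N(0, 1)."""
    return log_normal(w, 0.0, 1.0) + sum(log_normal(y, w, 1.0) for y in data)


def elbo(loc, scale, num_samples=2000, seed=0):
    """Monte Carlo ELBO estimate for the guide q(w) = N(loc, scale**2)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        w = rng.gauss(loc, scale)  # sample from the guide
        total += log_joint(w) - log_normal(w, loc, scale)
    return total / num_samples


# For this conjugate model the exact posterior is N(sum(y)/(n+1), 1/(n+1)).
n = len(DATA)
post_loc, post_scale = sum(DATA) / (n + 1), math.sqrt(1.0 / (n + 1))

print(elbo(post_loc, post_scale))  # guide equal to the posterior
print(elbo(0.0, 1.0))              # guide equal to the prior: lower ELBO
```

The first estimate does not depend on the seed. When q equals the posterior, log p(y, w) - log q(w) equals log p(y) for every sampled w.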


Before starting the training, let us import the MNIST dataset.

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

batch_size = 128  # assumed value

# ToTensor converts each 28x28 image to a float tensor with values in [0, 1]
train = MNIST("MNIST", train=True, download=True,
              transform=transforms.Compose([transforms.ToTensor()]))
test = MNIST("MNIST", train=False, download=True,
             transform=transforms.Compose([transforms.ToTensor()]))
dataloader_args = dict(shuffle=True, batch_size=batch_size,
                       num_workers=1, pin_memory=False)
train_loader = DataLoader(train, **dataloader_args)
test_loader = DataLoader(test, **dataloader_args)
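The MLP's forward method flattens each image with x.view((-1, nx)) before the first layer. The stdlib-only sketch below is an illustration, not torchvision code. It reproduces what ToTensor followed by that reshape yields for a single image, assuming an 8-bit grayscale 28x28 input. The function to_flat_input and the sample image are hypothetical.

```python
def to_flat_input(image):
    """Mimic ToTensor + view((-1, nx)) for one 8-bit grayscale image.

    `image` is a 28x28 nested list of pixel values in 0..255; the result is
    a flat, row-major list of 784 floats scaled to [0, 1].
    """
    return [pixel / 255.0 for row in image for pixel in row]


# Hypothetical image: a vertical bar of bright pixels in column 14
image = [[255 if col == 14 else 0 for col in range(28)] for row in range(28)]
x = to_flat_input(image)
print(len(x))        # 784, i.e. nx = 28 * 28
print(x[14], x[15])  # 1.0 0.0
```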


We can now launch the inference.

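from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

num_epochs = 5  # assumed value
inference = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())

for epoch in range(num_epochs):
    for j, (imgs, lbls) in enumerate(train_loader, 0):
        loss = inference.step(imgs.to(device), lbls.to(device))  # one ELBO step
    print("epoch {}: last mini-batch loss {:.1f}".format(epoch, loss))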
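Each call to SVI.step draws samples from the guide, estimates the ELBO, and takes one Adam step using gradients from PyTorch's autodiff. The stdlib-only sketch below runs that loop by hand on a toy problem to show what one step involves. The problem has a single weight w, prior N(0, 1), hypothetical observations y_i ~ N(w, 1), and a guide N(loc, softplus(raw)^2) parameterized like vnormal.

This is an illustration with stated simplifications, not Pyro's implementation:

- plain SGD replaces Adam;
- the gradients are derived by hand;
- iterates from the second half of the run are averaged to smooth out sampling noise.

Sampling w as loc + scale * eps (the reparameterization trick) lets the gradient flow through the sample.

```python
import math
import random

# Hypothetical toy data: prior w ~ N(0, 1), observations y_i ~ N(w, 1)
DATA = [2.0, 2.2, 1.8, 2.0]


def softplus(r):
    return math.log1p(math.exp(r))


def sigmoid(r):
    return 1.0 / (1.0 + math.exp(-r))


def dlog_joint(w):
    """d/dw [log p(w) + sum_i log p(y_i | w)] = sum(y) - (n + 1) * w."""
    return sum(DATA) - (len(DATA) + 1) * w


def fit_guide(steps=5000, lr=0.01, seed=0):
    """Plain SGD on a one-sample ELBO estimate with reparameterized samples.

    Guide: q(w) = N(loc, softplus(raw)**2). Returns loc and scale averaged
    over the second half of the run.
    """
    rng = random.Random(seed)
    loc, raw = 0.0, 0.0
    locs, scales = [], []
    for t in range(steps):
        scale = softplus(raw)
        eps = rng.gauss(0.0, 1.0)
        w = loc + scale * eps  # reparameterized sample from the guide
        g = dlog_joint(w)
        # ELBO sample = log p(y, w) + log(scale) + const, hence:
        #   d/dloc = g,  d/draw = (g * eps + 1 / scale) * sigmoid(raw)
        loc += lr * g
        raw += lr * (g * eps + 1.0 / scale) * sigmoid(raw)
        if t >= steps // 2:
            locs.append(loc)
            scales.append(softplus(raw))
    return sum(locs) / len(locs), sum(scales) / len(scales)


loc, scale = fit_guide()
print(loc, scale)
```

For this conjugate toy model the exact posterior is N(1.6, 0.2), so the fitted loc should land near 1.6 and the scale near 0.447.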


Once training is complete, the guide approximates the posterior distribution over the network parameters. Sampling it multiple times yields a set of MLPs, and we combine their predictions by averaging their outputs.

num_samples = 10  # number of networks sampled from the guide (assumed value)

def predict(x):
    # Sample num_samples MLPs from the trained guide and average their outputs
    sampled_models = [guide(None, None) for _ in range(num_samples)]
    yhats = [m(x).detach() for m in sampled_models]
    mean = torch.mean(torch.stack(yhats), 0)
    return torch.argmax(mean, dim=1)

correct = 0
total = 0
for j, data in enumerate(test_loader):
    images, labels = data
    predicted = predict(images.to(device))
    total += labels.size(0)
    correct += (predicted == labels.to(device)).sum().item()
print("accuracy: %d %%" % (100 * correct / total))
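The predict function averages the outputs of the sampled networks and takes the argmax. Below is a stdlib-only sketch of the same averaging for a single image. It adds one simple way to read uncertainty from the samples: the fraction of sampled networks whose own prediction agrees with the averaged one. That agreement score is an illustrative choice, not part of the original code. The function ensemble_predict and the output values are hypothetical.

```python
def ensemble_predict(outputs):
    """Average per-model output vectors; return (predicted class, agreement).

    `outputs` holds one score vector per sampled network for a single image.
    Agreement is the fraction of sampled networks whose own argmax matches
    the ensemble prediction, a simple illustrative uncertainty signal.
    """
    n_models = len(outputs)
    mean = [sum(col) / n_models for col in zip(*outputs)]
    pred = max(range(len(mean)), key=mean.__getitem__)
    votes = sum(1 for out in outputs
                if max(range(len(out)), key=out.__getitem__) == pred)
    return pred, votes / n_models


# Hypothetical outputs of three sampled MLPs for one image (3 classes shown)
outputs = [[2.0, 1.0, 0.1], [1.5, 1.7, 0.0], [2.2, 0.3, 0.4]]
print(ensemble_predict(outputs))  # (0, 0.6666666666666666)
```

Low agreement flags inputs on which the sampled networks disagree, which is the kind of uncertainty the Bayesian treatment of the weights is meant to expose.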

That’s it! You can now export the following environment variables with your WML credentials:

export ML_ENV=xxxxxxxxxxxxxxx
export ML_INSTANCE=xxxxxxxxxxxxxxx
export ML_USERNAME=xxxxxxxxxxxxxxx
export ML_PASSWORD=xxxxxxxxxxxxxxx

and run:

bx ml manifest.yml

where the archive contains all the Python source files (e.g., data loaders, etc.). This command returns an ID (e.g., training-xxxxxxxxx) that you can use to monitor the running job:

bx ml monitor training-runs training-xxxxxxxxx
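The archive of Python source files has to be built before the job is launched. Here is a minimal sketch using Python's standard zipfile module. It assumes a zip archive, and that all the .py files live under one source directory. The function build_archive and the paths are hypothetical.

```python
import os
import zipfile


def build_archive(src_dir, out_path):
    """Zip every .py file under src_dir, keeping paths relative to src_dir."""
    with zipfile.ZipFile(out_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, files in os.walk(src_dir):
            for name in sorted(files):
                if name.endswith(".py"):
                    full = os.path.join(root, name)
                    zf.write(full, os.path.relpath(full, src_dir))
    return out_path
```

Keeping paths relative to the source directory preserves the package layout, so imports between the training script and helper modules such as data loaders keep working once the archive is unpacked.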

Further reading on Pyro and DPPLs

More examples in both Edward and Pyro can be found in this paper.

Research Staff Member, IBM Research

Benjamin Herta

IBM Research
