Elastic distributed training

Use this document to learn how to run elastic distributed training workloads using GPU devices. GPU devices are dynamically allocated to the training job without stopping training.

Update model files

To run elastic distributed training, update your training model files to make the following two changes:

  1. Create a FabricModel instance to define an elastic distributed training model. See FabricModel definition for details.
  2. Start training the model by running the train function of FabricModel. See FabricModel methods for details.

FabricModel definition

To utilize elastic distributed training, update your model to include the FabricModel definition.

FabricModel definition:
class FabricModel(model, datasets_function, loss_function, optimizer, metrics=None, worker_logger_callback=None, driver_logger_callback=None, save_to_onnx=False, keras_custom_objects=None, fn_batch_collate=None)
Parameters:
  • model: Required. The model instance, either an instance of tf.keras.Model (TensorFlow) or torch.nn.Module (PyTorch).
  • datasets_function: Required. A Python function that returns the training and validation datasets. See Define dataset function.
  • loss_function: Required. Loss function. For TensorFlow, it can be a tf.keras.losses instance or a string of the loss function name. For PyTorch, it must be a callable loss function.
  • optimizer: Required. Optimizer for model training. For TensorFlow, it can be a tf.keras.optimizers instance or a string of the optimizer name. For PyTorch, it must be a torch.optim instance.
  • metrics: Optional. List of metrics to be evaluated by the model during training and testing. For TensorFlow, a metric can be a string (name of a built-in function), a function, or a tf.keras.metrics metric instance. For PyTorch, a metric can be a function or a callable instance.
  • worker_logger_callback: Optional. A logger callback to run in training worker. See Define custom logger callback.
  • driver_logger_callback: Optional. A logger callback to run in the driver worker. See Define custom logger callback.
  • save_to_onnx: Optional. Saves the final model in ONNX format. Must be a Boolean.
  • keras_custom_objects: Optional. Dictionary mapping names (strings) to custom classes or functions to be considered during deserialization.
  • fn_batch_collate: Optional. A custom function to collate batches from the input dataset. See Define a custom batch collation function.
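
For example, a FabricModel constructed from a PyTorch model might look like the following sketch. This is an illustrative example only: the small model, the accuracy metric, and its assumed (output, target) signature are not part of the product API, and get_dataset is assumed to be defined as described in Define dataset function:
import torch.nn as nn
import torch.optim as optim
from fabric_model import FabricModel

# Illustrative sketch: a small PyTorch model, loss function, optimizer, and metric.
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def accuracy(output, target):
    # Assumed metric signature: (model output, target labels).
    # Returns the fraction of correct predictions in a batch.
    return (output.argmax(dim=1) == target).float().mean().item()

edt_model = FabricModel(
    model,
    get_dataset,          # dataset function, see Define dataset function
    loss_function,
    optimizer,
    metrics=[accuracy],
    save_to_onnx=True,
)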

FabricModel methods

To start training a model, run the train function of FabricModel:
train(epoch_number, batch_size, engines_number=None, num_dataloader_threads=4, validation_freq=1, checkpoint_freq=1, effective_batch_size=None):
epoch_number
Required. Number of epochs to train the model. Must be an integer.
batch_size
Required. Local batch size to use per GPU during training. Must be an integer.
engines_number
Optional. Maximum number of GPUs to use during training. Must be an integer.
If effective_batch_size is not provided, batch_size * engines_number is the global batch size that is trained before synchronization among workers.
num_dataloader_threads
Optional. Number of threads to load data batches for model training. Must be an integer.
validation_freq
Optional. Number of epochs between model validation runs. Default is 1. Must be an integer.
checkpoint_freq
Optional. Number of epochs between model checkpoint saves. Default is 1. Must be an integer.
effective_batch_size
Optional. Global batch size across all workers. Mutually exclusive with engines_number. Must be an integer.
When only effective_batch_size is specified, engines_number is effective_batch_size / batch_size.
When both engines_number and effective_batch_size are specified, the larger resulting value of engines_number is used.
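
The relationship between these batch size parameters can be illustrated with a short arithmetic sketch; the values are illustrative only:
# Illustrative arithmetic only, mirroring the rules described above.
batch_size = 32               # local batch size per GPU
effective_batch_size = 256    # global batch size across all workers

# When only effective_batch_size is specified, the number of GPUs is derived from it:
engines_number = effective_batch_size // batch_size   # 8

# When only engines_number is specified, the global batch size trained before
# synchronization among workers is:
global_batch_size = batch_size * engines_number        # 256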

Get started

After you make the required changes to your model code, a simple training entry point can look like the following:
from fabric_model import FabricModel

def get_dataset():
    # Prepare or clean data
    return train_dataset, test_dataset

model = ...
optimizer = ...
loss_function = ...

edt_model = FabricModel(model, get_dataset, loss_function, optimizer)

epochs = ...
batch_size = ...
engines_number = ...
edt_model.train(epochs, batch_size, engines_number)
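
The optional keyword parameters that are described in FabricModel methods can be passed in the same train call, for example (an illustrative variation of the entry code above):
edt_model.train(epochs, batch_size, engines_number,
                num_dataloader_threads=4, validation_freq=1, checkpoint_freq=1)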

References

Define dataset function

A dataset function is defined as follows:
def get_dataset():
    # Prepare or clean data
    return train_dataset, test_dataset
Both the training and test datasets must be map-style datasets. For PyTorch, each can be an instance of torch.utils.data.Dataset. For TensorFlow, use a format similar to a PyTorch map-style dataset, with the __getitem__ and __len__ functions defined as shown below.
import numpy as np

class EDTTensorFlowDataset:
    def __init__(self, x, y) -> None:
        self.x = np.array(x)
        self.y = np.array(y)

    def __getitem__(self, index):
        """Gets sample at position `index`.
        Args:
            index: position of the sample in data set.
        Returns:
            tuple (x, y)
            x - feature
            y - label
            x and y can be scalar, numpy array or a dict mapping names to the corresponding array or scalar.
        """
        return self.x[index], self.y[index]

    def __len__(self):
        """Number of samples in the dataset.
        Returns:
            The number of samples in the dataset.
        """
        return len(self.x)
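
For PyTorch, an equivalent map-style dataset might look like the following sketch. The class name and the tensor conversions are illustrative assumptions, not a required implementation:
import numpy as np
import torch
from torch.utils.data import Dataset

class EDTPyTorchDataset(Dataset):
    """Illustrative map-style PyTorch dataset."""

    def __init__(self, x, y):
        self.x = torch.as_tensor(np.array(x), dtype=torch.float32)
        self.y = torch.as_tensor(np.array(y), dtype=torch.long)

    def __getitem__(self, index):
        # Returns a (feature, label) tuple for the sample at `index`
        return self.x[index], self.y[index]

    def __len__(self):
        # Number of samples in the dataset
        return len(self.x)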

Define custom logger callback

A custom logger callback can be defined as shown below and can run as either the driver logger callback or the worker logger callback. If a custom logger callback is used, it replaces the default logger callback.
class MyLoggerCallback():
    '''
    Abstract base class used to build new logger callbacks.
    '''

    def log_train_metrics(self, metrics, iteration, epoch, workers):
        '''
        Log metrics after training a batch.

        Parameters:
            metrics (dict): dictionary mapping metric names (strings) to their values. On driver, it will be the average accumulated
                            metric values among all batches in the current training epoch. On worker, it will be the per-step training metrics.
            iteration (int): current iteration number. On driver, it will be the total number of iterations across all workers that have 
                             already been run. On worker, it will be the total number of iterations the particular worker has already run.
            epoch (int): current epoch number.
            workers (int): number of training workers or GPUs
        '''

    def log_test_metrics(self, metrics, iteration, epoch):
        '''
        Log metrics after testing a batch.

        Parameters:
            metrics (dict): dictionary mapping metric names (strings) to their values. On driver, it will be the average accumulated
                            metric values among all test data. On worker, it will be the per-step test metrics.
            iteration (int): current iteration number. On driver, it will always be 0. On worker, it will be the total number of 
                             test iterations the particular worker has already run.
            epoch (int): current epoch number.
        '''

    def on_train_end(self):
        '''
        Log metrics when training is finished.
        '''

    def on_train_begin(self):
        '''
        Log metrics when training is started.
        '''
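
As a concrete illustration (an assumption-based sketch, not a class shipped with the product), a callback that prints metrics could look like the following. Such a callback can be passed through the worker_logger_callback or driver_logger_callback parameter of FabricModel:
class PrintLoggerCallback:
    '''
    Illustrative logger callback that prints metrics to stdout.
    '''

    def log_train_metrics(self, metrics, iteration, epoch, workers):
        print(f"[train] epoch={epoch} iteration={iteration} "
              f"workers={workers} metrics={metrics}")

    def log_test_metrics(self, metrics, iteration, epoch):
        print(f"[test] epoch={epoch} iteration={iteration} metrics={metrics}")

    def on_train_begin(self):
        print("Training started")

    def on_train_end(self):
        print("Training finished")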

Define a custom batch collation function

A custom batch collation function can be used to define how to combine samples from the original dataset generated by the dataset function.
def batch_collate_function(batch):
    '''
    Parameters:
        batch: a list of training samples; each sample is a tuple with one feature value and one label value, which is what the __getitem__ function of the dataset instance returns.
    Returns:
        A tuple (a batch of data features, a batch of data labels)
    '''
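
For example, a collate function that stacks the features and labels of a batch into NumPy arrays could look like the following sketch (an illustrative example, not a required implementation):
import numpy as np

def batch_collate_function(batch):
    # batch is a list of (feature, label) tuples produced by __getitem__.
    features = np.stack([sample[0] for sample in batch])
    labels = np.stack([sample[1] for sample in batch])
    return features, labels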

Running elastic distributed training

Launching elastic distributed training is similar to launching other regular training on Watson Machine Learning Accelerator, except that it requires specific values for the training interfaces.

Using data connection through REST API

Submit a POST request to the REST API /platform/rest/deeplearning/v1/execs to start the training workload.
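
For illustration, such a request could be submitted with Python requests. The host name, credentials, certificate handling, and payload field below are assumptions for illustration only; consult the REST API reference for the exact payload that your deployment expects:
import requests

# The host, credentials, and request body are illustrative assumptions only.
url = "https://<wmla-host>/platform/rest/deeplearning/v1/execs"
payload = {"args": "--exec-start edtPyTorch --numWorker 2 --workerDeviceNum 1"}  # assumed field name
response = requests.post(url, auth=("<username>", "<password>"), data=payload, verify=False)
print(response.status_code, response.text)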

Using data connection through the command line interface (CLI)

Use the Watson Machine Learning Accelerator CLI. You can download the CLI from the Watson Machine Learning Accelerator console. See dlicmd.py reference.

These are the primary parameters that differ for elastic distributed training compared with other regular training on Watson Machine Learning Accelerator:
--exec-start <edtPyTorch for PyTorch and edtTensorflow for TensorFlow>
--numWorker <maximum number of worker nodes>
--workerDeviceNum <number of GPU devices for each worker node>
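
For example, an invocation might look like the following sketch; the connection options and any other arguments that your deployment requires are omitted and must be added:
python dlicmd.py --exec-start edtPyTorch --numWorker 4 --workerDeviceNum 1 <other deployment-specific options>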

When numWorker is larger than 1 and a new worker is added to the training, the new worker requires the latest model weights from the other workers. The environment variable EDT_PULL_WEIGHT_TIME_OUT defines the longest time that the new worker waits for the model weights to be ready to load. If the model weights are not ready after that time, a PULL_FAILED error message is logged in the worker and the worker asks for the model weights again. If the failure happens more times than the maximum that is allowed, which is defined with the environment variable EDT_PULL_FAILED_LIMITS, the new worker is stopped. Environment variables can be customized through --msd-env env_var_name=env_var_value in the training interfaces.
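
For example, both environment variables can be set when the training is submitted; the values shown are placeholders, and passing each variable with its own --msd-env option is an assumption:
--msd-env EDT_PULL_WEIGHT_TIME_OUT=<timeout_value> --msd-env EDT_PULL_FAILED_LIMITS=<maximum_failures>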

Examples

Model samples are available from https://github.com/IBM/wmla-assets/tree/master/wmla-samples.