Multi-GPU and multi-node workloads

You can launch distributed workloads across multiple GPUs -- either within a single node or across multiple nodes -- using the Serverless GPU Python API. The API provides a simple, unified interface that abstracts away the details of GPU provisioning, environment setup, and workload distribution. With minimal code changes, you can move from single-GPU training to distributed execution across remote GPUs from the same notebook.

Quick start

The serverless GPU API for distributed training is preinstalled in serverless GPU compute environments for Databricks notebooks; we recommend GPU environment version 4 or above. To distribute a training function, import the distributed decorator and apply it to the function.

The code snippet below shows the basic usage of @distributed:

# Import the distributed decorator
from serverless_gpu import distributed

# Decorate your training function with @distributed and specify the number of GPUs, the GPU type,
# and whether or not the GPUs are remote
@distributed(gpus=8, gpu_type='A10', remote=True)
def run_train():
    ...

Below is a full example that trains a multilayer perceptron (MLP) model on 8 A10 GPUs from a notebook:

  1. Set up your model and define utility functions.

    
    # Define the model
    import os
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    
    def setup():
        dist.init_process_group("nccl")
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    
    def cleanup():
        dist.destroy_process_group()
    
    class SimpleMLP(nn.Module):
        def __init__(self, input_dim=10, hidden_dim=64, output_dim=1):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(hidden_dim, output_dim)
            )
    
        def forward(self, x):
            return self.net(x)
    
  2. Import the serverless_gpu library and the distributed decorator.

    import serverless_gpu
    from serverless_gpu import distributed
    
  3. Wrap the model training code in a function and decorate the function with the @distributed decorator.

    @distributed(gpus=8, gpu_type='A10', remote=True)
    def run_train(num_epochs: int, batch_size: int) -> None:
        import mlflow
        import torch.optim as optim
        from torch.nn.parallel import DistributedDataParallel as DDP
        from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
    
        # 1. Set up multi node environment
        setup()
        device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
    
        # 2. Apply Torch distributed data parallel (DDP) for data-parallel training.
        model = SimpleMLP().to(device)
        model = DDP(model, device_ids=[device])
    
        # 3. Create and load dataset.
        x = torch.randn(5000, 10)
        y = torch.randn(5000, 1)
    
        dataset = TensorDataset(x, y)
        sampler = DistributedSampler(dataset)
        dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
    
        # 4. Define the training loop.
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        loss_fn = nn.MSELoss()
    
        for epoch in range(num_epochs):
            sampler.set_epoch(epoch)
            model.train()
            total_loss = 0.0
            for step, (xb, yb) in enumerate(dataloader):
                xb, yb = xb.to(device), yb.to(device)
                optimizer.zero_grad()
                loss = loss_fn(model(xb), yb)
                # Log loss to MLflow metric
                mlflow.log_metric("loss", loss.item(), step=step)
    
                loss.backward()
                optimizer.step()
                total_loss += loss.item() * xb.size(0)
    
            mlflow.log_metric("total_loss", total_loss)
            print(f"Total loss for epoch {epoch}: {total_loss}")
    
        cleanup()
    
  4. Execute the distributed training by calling .distributed() on the decorated function with user-defined arguments.

    run_train.distributed(num_epochs=3, batch_size=1)
    
  5. When executed, an MLflow run link is generated in the notebook cell output. Click the MLflow run link or find it in the Experiment panel to see the run results.

    Output in the notebook cell

Distributed execution details

The Serverless GPU API consists of several key components:

  • Compute manager: Handles resource allocation and management
  • Runtime environment: Manages Python environments and dependencies
  • Launcher: Orchestrates job execution and monitoring

When running in distributed mode:

  • The function is serialized and distributed across the specified number of GPUs
  • Each GPU runs a copy of the function with the same parameters
  • The environment is synchronized across all nodes
  • Results are collected and returned from all GPUs
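
The sketch below illustrates the last point. It assumes that the value each GPU's copy of the function returns is gathered and handed back to the notebook caller; the exact return shape (shown here as one entry per GPU) is an assumption, not documented behavior.

from serverless_gpu import distributed

@distributed(gpus=2, gpu_type='A10', remote=True)
def report_device() -> str:
    import os
    import torch

    # Each GPU runs its own copy of this function with the same parameters.
    local_rank = int(os.environ["LOCAL_RANK"])
    return f"{torch.cuda.get_device_name(local_rank)} (local rank {local_rank})"

# Assumption: the per-GPU return values are collected and returned to the caller.
results = report_device.distributed()
print(results)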

If remote is set to True, the workload is distributed across remote GPUs. If remote is set to False, the workload runs on the GPU node attached to the current notebook; if that node has multiple GPUs, all of them are used.
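
For example, a minimal sketch of local execution (the gpus value is an assumption about the attached node's GPU count; adjust it to match your hardware):

from serverless_gpu import distributed

# remote=False keeps the workload on the GPU node attached to this notebook.
# gpus=4 assumes the node exposes 4 GPUs; adjust to match your hardware.
@distributed(gpus=4, gpu_type='A10', remote=False)
def run_local_train():
    ...

run_local_train.distributed()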

The API supports popular parallel training libraries such as Distributed Data Parallel (DDP), Fully Sharded Data Parallel (FSDP), DeepSpeed and Ray.

You can find more realistic distributed training scenarios that use these libraries in the notebook examples.
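
As one illustration, the DDP training function from the quick start can be adapted to FSDP by swapping the model wrapper. The sketch below is not taken from the notebook examples; it assumes the setup(), cleanup(), and SimpleMLP definitions from step 1 above.

from serverless_gpu import distributed

@distributed(gpus=8, gpu_type='A10', remote=True)
def run_train_fsdp(num_epochs: int, batch_size: int) -> None:
    import torch.optim as optim
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    setup()
    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")

    # FSDP shards parameters, gradients, and optimizer state across ranks
    # instead of replicating the full model on every GPU as DDP does.
    model = FSDP(SimpleMLP().to(device))

    dataset = TensorDataset(torch.randn(5000, 10), torch.randn(5000, 1))
    sampler = DistributedSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)

    # Build the optimizer after wrapping so it references the sharded parameters.
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)
        model.train()
        for xb, yb in dataloader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            optimizer.step()

    cleanup()

run_train_fsdp.distributed(num_epochs=3, batch_size=32)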

Launch with Ray

The serverless GPU API also supports launching distributed training with Ray through the @ray_launch decorator, which is layered on top of @distributed. Each ray_launch task first bootstraps a torch-distributed rendezvous to choose the Ray head worker and gather node IPs. Rank zero runs ray start --head (with metrics export if enabled), sets RAY_ADDRESS, and executes your decorated function as the Ray driver. The other nodes join via ray start --address and wait until the driver writes a completion marker.

Additional configuration details:

  • To enable Ray system metrics collection on each node, use RayMetricsMonitor with remote=True.
  • Define Ray runtime options (actors, datasets, placement groups, and scheduling) inside your decorated function using standard Ray APIs.
  • Manage cluster-wide controls (GPU count and type, remote vs. local mode, async behavior, and Databricks pool environment variables) outside the function in the decorator arguments or notebook environment.

The example below shows how to use @ray_launch:

from serverless_gpu.ray import ray_launch

@ray_launch(gpus=16, remote=True, gpu_type='A10')
def foo():
    import ray

    # Print the resources available on each node of the Ray cluster.
    print(ray.state.available_resources_per_node())
    return 1

foo.distributed()
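
Ray runtime options such as tasks, actors, and resource requests are defined inside the decorated function with standard Ray APIs, as noted in the list above. The sketch below is illustrative: the task body, the num_gpus request, and the function name are assumptions, not part of the serverless GPU API.

from serverless_gpu.ray import ray_launch

@ray_launch(gpus=16, remote=True, gpu_type='A10')
def run_ray_tasks():
    import ray
    import torch

    # Standard Ray API: request one GPU per task; scheduling happens inside the function.
    @ray.remote(num_gpus=1)
    def gpu_task(i: int) -> str:
        return f"task {i} ran on {torch.cuda.get_device_name(0)}"

    # The launcher has already connected this driver to the Ray cluster (see above).
    results = ray.get([gpu_task.remote(i) for i in range(16)])
    for line in results:
        print(line)

run_ray_tasks.distributed()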

For a complete example, see this notebook, which launches Ray to train a ResNet-18 neural network on multiple A10 GPUs.

FAQs

Where should the data loading code be placed?

When using the Serverless GPU API for distributed training, place the data loading code inside the function decorated with @distributed. A dataset can exceed the maximum size that pickle allows, so it is recommended to generate or load the dataset inside the decorated function, as shown below:

from serverless_gpu import distributed

# Loading the dataset at module scope may cause a pickle error
dataset = get_dataset(file_path)

@distributed(gpus=8, remote=True)
def run_train():
    # Good practice: load the dataset inside the decorated function
    dataset = get_dataset(file_path)
    ...

How do I use a reserved GPU pool?

If a reserved GPU pool is available in your workspace (check with your admin) and you set remote to True in the @distributed decorator, the workload is launched on the reserved GPU pool by default. To use the on-demand GPU pool instead, set the environment variable DATABRICKS_USE_RESERVED_GPU_POOL to False before calling the distributed function.
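
For example, a minimal sketch (the string value "False" is an assumption about how the launcher parses the variable):

import os

# Opt out of the reserved GPU pool so the workload uses the on-demand pool instead.
os.environ["DATABRICKS_USE_RESERVED_GPU_POOL"] = "False"

run_train.distributed(num_epochs=3, batch_size=1)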

Learn more

For the API reference, refer to the Serverless GPU Python API documentation.