You can launch distributed workloads across multiple GPUs, either within a single node or across multiple nodes, using the Serverless GPU Python API. The API provides a simple, unified interface that abstracts away the details of GPU provisioning, environment setup, and workload distribution. With minimal code changes, you can move from single-GPU training to distributed execution across remote GPUs from the same notebook.
Quick start
The Serverless GPU API for distributed training is preinstalled in serverless GPU compute
environments for Databricks notebooks. We recommend GPU environment 4 or above. To use it for distributed training, import the
distributed decorator and apply it to your training function.
The code snippet below shows the basic usage of @distributed:
# Import the distributed decorator
from serverless_gpu import distributed

# Decorate your training function with @distributed and specify the number of GPUs, the GPU type,
# and whether or not the GPUs are remote
@distributed(gpus=8, gpu_type='A10', remote=True)
def run_train():
    ...
Below is a full example that trains a multilayer perceptron (MLP) model on 8 A10 GPU nodes from a notebook:
Set up your model and define utility functions.
# Define the model
import os

import torch
import torch.distributed as dist
import torch.nn as nn


def setup():
    # Join the NCCL process group and pin this process to its local GPU
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


def cleanup():
    dist.destroy_process_group()


class SimpleMLP(nn.Module):
    def __init__(self, input_dim=10, hidden_dim=64, output_dim=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)
Import the serverless_gpu library and the distributed module.
import serverless_gpu
from serverless_gpu import distributed
Wrap the model training code in a function and decorate the function with the @distributed decorator.
@distributed(gpus=8, gpu_type='A10', remote=True)
def run_train(num_epochs: int, batch_size: int) -> None:
    import mlflow
    import torch.optim as optim
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    # 1. Set up the multi-node environment.
    setup()
    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")

    # 2. Apply the Torch distributed data parallel (DDP) library for data-parallel training.
    model = SimpleMLP().to(device)
    model = DDP(model, device_ids=[device])

    # 3. Create and load the dataset.
    x = torch.randn(5000, 10)
    y = torch.randn(5000, 1)
    dataset = TensorDataset(x, y)
    sampler = DistributedSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)

    # 4. Define the training loop.
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)
        model.train()
        total_loss = 0.0
        for step, (xb, yb) in enumerate(dataloader):
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(xb), yb)
            # Log the loss as an MLflow metric
            mlflow.log_metric("loss", loss.item(), step=step)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * xb.size(0)
        mlflow.log_metric("total_loss", total_loss)
        print(f"Total loss for epoch {epoch}: {total_loss}")
    cleanup()
Execute the distributed training by calling the distributed function with user-defined arguments.
run_train.distributed(num_epochs=3, batch_size=1)
When executed, an MLflow run link is generated in the notebook cell output. Click the MLflow run link or find it in the Experiment panel to see the run results.
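In this example, DistributedSampler splits the 5,000 synthetic samples across ranks, so each GPU trains on only its own shard. Assuming gpus=8 yields a world size of 8, that is 5,000 / 8 = 625 samples, and therefore 625 steps at batch_size=1, per rank per epoch. The standard-library sketch below imitates the sampler's partitioning scheme (a seeded shuffle, padding to a multiple of the world size, then every num_replicas-th index starting at the rank) so you can check that arithmetic. shard_indices is a hypothetical helper, and it shuffles with Python's random module, so its permutation differs from torch's; treat it as an illustration of the scheme, not a replacement for DistributedSampler.

```python
import math
import random
from typing import List


def shard_indices(dataset_len: int, num_replicas: int, rank: int,
                  epoch: int = 0, seed: int = 0, shuffle: bool = True) -> List[int]:
    """Return the dataset indices one rank sees in a given epoch.

    Hypothetical helper that imitates DistributedSampler's scheme; it uses
    Python's random module, so the actual permutation differs from torch's.
    """
    indices = list(range(dataset_len))
    if shuffle:
        # Every rank uses the same seed, so all ranks agree on one permutation.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating indices so the total divides evenly across replicas.
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    padding = total_size - len(indices)
    if padding > 0:
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


# Usage: 5,000 samples across 8 ranks, as in the training example above.
shards = [shard_indices(5000, 8, rank, epoch=0) for rank in range(8)]
print([len(s) for s in shards])  # 625 samples (625 batches at batch_size=1) per rank
```

Because the shuffle seed includes the epoch, skipping sampler.set_epoch(epoch) in the training loop would make every epoch reuse the same ordering.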

Distributed execution details
The Serverless GPU API consists of several key components:
- Compute manager: Handles resource allocation and management
- Runtime environment: Manages Python environments and dependencies
- Launcher: Orchestrates job execution and monitoring
When running in distributed mode:
- The function is serialized and distributed across the specified number of GPUs
- Each GPU runs a copy of the function with the same parameters
- The environment is synchronized across all nodes
- Results are collected and returned from all GPUs
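The execution model in the list above can be illustrated with a purely local analogy. In the sketch below, the hypothetical simulate_distributed helper serializes the arguments once with pickle, gives every simulated worker its own deserialized copy, runs the same function with the same arguments on each worker concurrently, and collects the results in rank order. It uses threads instead of GPUs, does not serialize the function itself, and passes the rank as a keyword argument (real workers read it from environment variables such as LOCAL_RANK, as the training example does). It is an assumption-laden sketch, not how the Serverless GPU API is implemented.

```python
import pickle
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List


def simulate_distributed(fn: Callable[..., Any], world_size: int,
                         *args: Any, **kwargs: Any) -> List[Any]:
    """Hypothetical local stand-in for a distributed launch (threads, not GPUs)."""
    # Serialize the arguments once, as a launcher would before shipping them.
    payload = pickle.dumps((args, kwargs))

    def worker(rank: int) -> Any:
        # Each worker deserializes its own independent copy of the arguments.
        worker_args, worker_kwargs = pickle.loads(payload)
        return fn(*worker_args, rank=rank, **worker_kwargs)

    # Run one copy per simulated GPU; map() returns results in rank order.
    with ThreadPoolExecutor(max_workers=world_size) as pool:
        return list(pool.map(worker, range(world_size)))


def record_rank(history: list, rank: int) -> list:
    """Toy workload: mutating `history` only affects this worker's copy."""
    history.append(rank)
    return history


results = simulate_distributed(record_rank, 4, [])
print(results)  # [[0], [1], [2], [3]]
```

Each worker starts from the same empty list, but because every worker gets its own copy, no worker sees another worker's changes.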
If remote is set to True, the workload runs on remote GPUs. If remote is set to
False, the workload runs on the single GPU node attached to the current notebook; if that
node has multiple GPUs, all of them are used.
The API supports popular parallel training libraries such as Distributed Data Parallel (DDP), Fully Sharded Data Parallel (FSDP), DeepSpeed and Ray.
For more real-world distributed training scenarios that use these libraries, see the notebook examples.
Launch with Ray
The Serverless GPU API also supports launching distributed training with Ray through the @ray_launch
decorator, which is layered on top of @distributed.
Each ray_launch task first bootstraps a torch.distributed rendezvous to choose the Ray head worker
and gather node IPs. The rank-zero worker runs ray start --head (with metrics export, if enabled), sets
RAY_ADDRESS, and runs your decorated function as the Ray driver. The other nodes join with
ray start --address and wait until the driver writes a completion marker.
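The bootstrap sequence above boils down to two small pieces of logic: choosing the ray start command for each rank, and having non-driver nodes wait for a completion marker. The ray start flags --head, --port, and --address and the RAY_ADDRESS environment variable are standard Ray. The helper names ray_node_plan and wait_for_marker, the use of Ray's default port 6379, and the modeling of the completion marker as a local file are assumptions for illustration; they do not describe how @ray_launch actually implements these steps. The sketch only builds commands and never starts Ray.

```python
import os
import tempfile
import time
from typing import Dict, List, Tuple

RAY_PORT = 6379  # Ray's default head-node port (assumed; the launcher may differ)


def ray_node_plan(rank: int, head_ip: str,
                  port: int = RAY_PORT) -> Tuple[List[str], Dict[str, str]]:
    """Hypothetical helper: the `ray start` command and env vars for one rank."""
    address = f"{head_ip}:{port}"
    if rank == 0:
        # Rank zero becomes the head node and later runs the driver function,
        # so it also exports RAY_ADDRESS for ray.init() to pick up.
        return ["ray", "start", "--head", f"--port={port}"], {"RAY_ADDRESS": address}
    # Every other rank joins the head node's address.
    return ["ray", "start", f"--address={address}"], {}


def wait_for_marker(path: str, timeout_s: float, poll_s: float = 0.02) -> bool:
    """Hypothetical helper: poll until the driver's completion marker exists."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if os.path.exists(path):
            return True
        time.sleep(poll_s)
    return os.path.exists(path)


# Usage: plan a three-node cluster whose rendezvous elected 10.0.0.1 as head.
plans = [ray_node_plan(rank, "10.0.0.1") for rank in range(3)]
for rank, (cmd, env) in enumerate(plans):
    print(rank, " ".join(cmd), env)

with tempfile.TemporaryDirectory() as tmp:
    marker = os.path.join(tmp, "driver_done")  # hypothetical marker location
    done_before = wait_for_marker(marker, timeout_s=0.1)  # driver still running
    open(marker, "w").close()  # driver (rank 0) signals completion
    done_after = wait_for_marker(marker, timeout_s=0.1)
```

Keeping the rank-zero decision in one place means every node derives its role from the same rendezvous result, so exactly one node starts the head and runs the driver.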
Additional configuration details:
- To enable Ray system metrics collection on each node, use RayMetricsMonitor with remote=True.
- Define Ray runtime options (actors, datasets, placement groups, and scheduling) inside your decorated function using standard Ray APIs.
- Manage cluster-wide controls (GPU count and type, remote vs. local mode, async behavior, and Databricks pool environment variables) outside the function in the decorator arguments or notebook environment.
The example below shows how to use @ray_launch:
from serverless_gpu.ray import ray_launch

@ray_launch(gpus=16, remote=True, gpu_type='A10')
def foo():
    import os
    import ray
    print(ray.state.available_resources_per_node())
    return 1

foo.distributed()
For a complete example, see this notebook, which launches Ray to train a ResNet-18 neural network on multiple A10 GPUs.
FAQs
Where should the data loading code be placed?
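When using the Serverless GPU API for distributed training, place data loading code inside the function decorated with @distributed. Because the decorated function is serialized and sent to the GPUs, a dataset loaded outside the function is serialized along with it, and its size can exceed the maximum size allowed by pickle. Load or generate the dataset inside the decorated function instead, as shown below: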
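from serverless_gpu import distributed

# get_dataset and file_path are placeholders for your own loader and data path.

# Avoid: a dataset loaded here is serialized with the function and may cause a pickle error.
dataset = get_dataset(file_path)

@distributed(gpus=8, remote=True)
def run_train():
    # Good practice: load the dataset inside the decorated function.
    dataset = get_dataset(file_path)
    ...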
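The size problem comes from what gets serialized. The sketch below compares the pickled payload for the two patterns: shipping an already-loaded dataset versus shipping only the path that the function needs to load it. get_dataset here is a hypothetical stand-in that fabricates 10 MB of zero bytes instead of reading a file, and the file path is a placeholder; only the relative sizes matter.

```python
import pickle


def get_dataset(file_path: str) -> bytes:
    """Hypothetical loader: pretend we read 10 MB of training data from file_path."""
    return bytes(10_000_000)


file_path = "/tmp/train.bin"  # hypothetical path

# Loading outside the function: the data itself must be serialized and shipped.
eager_payload = pickle.dumps({"dataset": get_dataset(file_path)})
# Loading inside the function: only the short path string is shipped.
lazy_payload = pickle.dumps({"file_path": file_path})

print(f"eager: {len(eager_payload):,} bytes, lazy: {len(lazy_payload):,} bytes")
```

The eager payload grows with the dataset, while the lazy payload stays a few dozen bytes no matter how large the data on disk is.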
How do I use the reserved pool?
If a reserved GPU pool is available in your workspace (check with your admin) and you set
remote to True in the @distributed decorator, the workload is launched on the reserved GPU
pool by default. To use the on-demand GPU pool instead, set the environment variable
DATABRICKS_USE_RESERVED_GPU_POOL to False before calling the distributed function.
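In a notebook, you can set the variable from Python in a cell that runs before the distributed call. This is a minimal sketch: the variable name comes from this page, but the exact accepted value is an assumption (the name suggests that a false value opts out of the reserved pool), so confirm it in the Serverless GPU Python API documentation.

```python
import os

# Opt out of the reserved GPU pool for distributed calls made later in this
# notebook session. The value "False" is an assumption; check the API reference.
os.environ["DATABRICKS_USE_RESERVED_GPU_POOL"] = "False"
```

After this cell runs, call your decorated function (for example, run_train.distributed(...)) as usual.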
Learn more
For the API reference, refer to the Serverless GPU Python API documentation.