A Beginner’s Guide to Distributed Training with Mistral Models

Towards AGI · Published in Level Up Coding · Mar 28, 2024


Distributed training is a powerful technique that lets you train large models across multiple GPUs and machines, significantly reducing training time. PyTorch provides a robust distributed training framework, but implementing it can be complex, especially for beginners. In this article, we’ll take a deep dive into example code that demonstrates how to set up and use distributed training in PyTorch, covering the key concepts, functions, and best practices for efficient distributed training.

Initializing Distributed Training

The first step in distributed training is to initialize the distributed environment. This is typically done using the torch.distributed.init_process_group function, which initializes the default distributed process group.

import torch.distributed as dist

dist.init_process_group(backend='nccl', init_method='env://')

Here, we’re using the nccl backend, which is the recommended backend for multi-GPU training. The init_method argument specifies how the initial connection between the processes will be established. The env:// method uses environment variables to set up the connection.
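With env://, PyTorch reads the connection details from environment variables. As a minimal sketch (assuming a single-process sanity check; a real launcher such as torchrun sets these variables for you), the required variables can be set by hand:

import os
import torch.distributed as dist

# These variables are normally set by the launcher (e.g., torchrun).
# Here we set them manually for a one-process test.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # address of rank 0
os.environ.setdefault("MASTER_PORT", "29500")      # a free TCP port
os.environ.setdefault("RANK", "0")                 # global rank of this process
os.environ.setdefault("WORLD_SIZE", "1")           # total number of processes

# 'gloo' works on CPU-only machines; use 'nccl' when GPUs are available.
dist.init_process_group(backend="gloo", init_method="env://")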

Setting Up Process Groups

In distributed training, processes are organized into groups. The code below defines two types of process groups:

  1. Replica group — Each replica group contains the ranks that hold the same model shard across all replicas. Gradients are synchronized (all-reduced) across this group during the backward pass.
  2. Shard group — Each shard group contains the ranks whose shards together make up one full copy of the model, with each rank owning one slice of the parameters. During the forward pass, each shard processes its subset of the data in parallel.

The our_initialize_model_parallel function sets up these process groups based on the number of replicas and shards specified.

def our_initialize_model_parallel(
    _backend: Optional[str] = None,
    n_replica: int = 1,
) -> None:
    # ...
    groups = torch.LongTensor(range(world_size)).reshape(n_replica, shard_size)
    # ...
    for j in range(shard_size):
        ranks = groups[:, j].tolist()
        group = torch.distributed.new_group(ranks, backend=_backend)
        if j == found[1]:
            _REPLICA_GROUP = group
    # ...
    for i in range(n_replica):
        group = torch.distributed.new_group(groups[i, :].tolist(), backend=_backend)
        if i == found[0]:
            _SHARD_GROUP = group

The function first reshapes the ranks into a 2D tensor, where each row represents a replica and each column represents a shard. It then creates new process groups for each replica and shard using torch.distributed.new_group.
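To make the layout concrete, here is a small worked example (not part of the original code) with 6 ranks arranged as 2 replicas of 3 shards each:

import torch

world_size, n_replica, shard_size = 6, 2, 3
groups = torch.LongTensor(range(world_size)).reshape(n_replica, shard_size)
print(groups)
# tensor([[0, 1, 2],
#         [3, 4, 5]])

# Rows -> shard groups: ranks that together hold one full model replica.
print([groups[i, :].tolist() for i in range(n_replica)])   # [[0, 1, 2], [3, 4, 5]]

# Columns -> replica groups: ranks holding the same shard across replicas;
# gradients for that shard are all-reduced within one of these groups.
print([groups[:, j].tolist() for j in range(shard_size)])  # [[0, 3], [1, 4], [2, 5]]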

Sharding the Model

With the process groups set up, the next step is to shard the model across the available GPUs. This is typically done by splitting the model into roughly equal-sized shards and assigning each shard to a different GPU. The code here doesn’t include the actual model sharding logic, but it does provide helper functions for getting the shard group and shard rank.

from functools import lru_cache

import torch
from torch.distributed import ProcessGroup


@lru_cache()
def get_shard_group() -> ProcessGroup:
    """Get the shard group the caller rank belongs to."""
    assert _SHARD_GROUP is not None, "shard group is not initialized"
    return _SHARD_GROUP


@lru_cache()
def get_shard_rank() -> int:
    return torch.distributed.get_rank(group=get_shard_group())

These functions are used to obtain the shard group and rank for the current process, which are used during the forward and backward passes.
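For illustration only, here is one way layer-wise sharding could look, assuming the model is a simple list of transformer blocks (the shard_layers helper is hypothetical and not part of the original code):

import torch
import torch.nn as nn

def shard_layers(layers: nn.ModuleList) -> nn.ModuleList:
    # Hypothetical sketch: keep only the contiguous slice of layers that
    # belongs to this rank's position within its shard group.
    shard_rank = get_shard_rank()
    shard_size = torch.distributed.get_world_size(group=get_shard_group())
    per_shard = (len(layers) + shard_size - 1) // shard_size  # ceil division
    start = shard_rank * per_shard
    end = min(start + per_shard, len(layers))
    return nn.ModuleList(list(layers)[start:end])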

Training Loop

With the model sharded and process groups set up, the training loop can proceed largely as usual, with a few key modifications (a minimal sketch follows the list):

  1. Each rank only iterates over its shard of the dataset, typically using a DistributedSampler.
  2. In the forward pass, each shard independently processes its subset of the data. No communication happens between shards at this stage.
  3. In the backward pass, gradients are first synchronized across replicas using torch.distributed.all_reduce with the replica process group. This ensures that each shard has the full gradient for its model slice.
  4. The optimizer step is applied independently on each shard, updating the local model parameters.
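Here is a minimal sketch of such a loop; the model, dataset, num_epochs, and a get_replica_group() helper mirroring get_shard_group() are assumptions, not part of the original code:

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

# Assumed to exist alongside get_shard_group(); returns _REPLICA_GROUP.
replica_group = get_replica_group()

sampler = DistributedSampler(train_dataset)          # hypothetical dataset
loader = DataLoader(train_dataset, batch_size=8, sampler=sampler)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # hypothetical model

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)          # reshuffle differently each epoch
    for batch in loader:
        loss = model(batch)           # forward pass on this rank's data
        loss.backward()               # backward pass computes local gradients

        # Average gradients across replicas for this rank's model slice.
        for p in model.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, group=replica_group)

        optimizer.step()              # each shard updates its own parameters
        optimizer.zero_grad()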

The original code doesn’t include a full training loop, but it does show the same all_reduce mechanism in action, here used to average a metric across processes.

from typing import Union

def avg_aggregate(metric: Union[float, int]) -> Union[float, int]:
    buffer = torch.tensor([metric], dtype=torch.float32, device="cuda")
    dist.all_reduce(buffer, op=dist.ReduceOp.AVG)
    return buffer[0].item()

This function takes a metric (e.g., loss or accuracy) and averages it across every process in the default group using torch.distributed.all_reduce with the AVG reduce operator.
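For example, a rank-0 process might log the averaged training loss at the end of an epoch (loss and epoch here are hypothetical values from your own loop):

local_loss = loss.item()                 # this rank's local loss value
global_loss = avg_aggregate(local_loss)  # averaged over all processes
if dist.get_rank() == 0:
    print(f"epoch {epoch}: avg loss {global_loss:.4f}")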

Setting the Device

In multi-GPU training, it’s important to ensure that each process is using the correct GPU. The set_device function in the provided code takes care of this.

import os

def set_device():
    # ...
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

The function reads the LOCAL_RANK environment variable, which is typically set by the distributed training launcher (e.g., torchrun or the older torch.distributed.launch), and sets the CUDA device accordingly.
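For example, a single-node run on four GPUs might be launched like this (train.py is a placeholder for your entry script); torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT for every process it spawns:

torchrun --nproc_per_node=4 train.py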

Conclusion

Distributed training is a powerful technique for training large models on multiple GPUs and machines. PyTorch provides a flexible and efficient distributed training framework, but it can be complex to set up and use correctly. The code walked through here demonstrates some of the key concepts and best practices for distributed training in PyTorch, including:

  1. Initializing the distributed environment
  2. Setting up process groups for replicas and shards
  3. Sharding the model across GPUs
  4. Synchronizing gradients across replicas
  5. Setting the correct CUDA device for each process

By understanding and applying these concepts, you can effectively scale your PyTorch models to take advantage of multiple GPUs and machines, significantly reducing training time and enabling the training of larger and more complex models.

As we finish up this article, you might be interested in using a variety of AI models, but how do you manage so many AI APIs at once?

That’s where Anakin AI comes in! Anakin AI is your all-in-one AI model aggregator, where you can easily use any language model or image generation model and build your dream AI app within minutes, not days, with a no-code UI!

Anakin AI, the All-in-One AI Model Aggregator
