My from-scratch implementation (using only PyTorch primitives) of 3D parallelism, which combines data parallelism, tensor parallelism, and pipeline parallelism. This requires manual backpropagation and communicating gradients across layers on each rank.

The following code snippets are inspired by the amazing Stanford CS336: Language Modeling from Scratch.


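The challenging part of combining the different parallelism implementations is that tensor, pipeline, and data parallelism each communicate among a different set of ranks (a different process group). For 3D parallelism on 8 ranks, we have 2 data-parallel (DP) replicas –> each replica is split into 2 pipeline-parallel (PP) stages –> each stage is sharded across 2 tensor-parallel (TP) ranks (2 × 2 × 2 = 8).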

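Most importantly, we need to create the distributed process groups (one per tensor-parallel pair and one per data-parallel pair) that the collectives run on.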

from inspect import isfunction
from typing import Callable
import sys
import torch.multiprocessing as mp
import torch.nn as nn
import math

import torch
import time
import os
from typing import List, Callable
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.fsdp

class DisableDistributed:
    """Temporarily replace every plain function on ``torch.distributed`` with a no-op.

    On entry, each attribute of ``dist`` accepted by ``inspect.isfunction`` is
    remembered and swapped for a callable that ignores its arguments and
    returns ``None``. On exit the remembered originals are restored. The
    context value (``as`` target) is ``None``.
    """

    def __enter__(self):
        self.old_functions = saved = {}
        for attr_name in dir(dist):
            candidate = getattr(dist, attr_name, None)
            if not isfunction(candidate):
                continue
            saved[attr_name] = candidate
            setattr(dist, attr_name, lambda *args, **kwargs: None)

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the originals; returning None lets any exception propagate.
        for attr_name, original in self.old_functions.items():
            setattr(dist, attr_name, original)


def spawn(func: Callable, world_size: int, *args, **kwargs):
    """Launch ``func`` on ``world_size`` worker processes.

    ``func`` is called with positional arguments
    ``(rank, world_size, *args, *kwargs.values())``. Keyword arguments are
    flattened positionally in the order they were passed, so they must line
    up with ``func``'s parameter order.

    When a tracer is active (``sys.gettrace()`` is truthy), ``func`` is instead
    called once in-process as rank 0 with every ``torch.distributed`` function
    stubbed out, since ``mp.spawn`` cannot be traced through.
    """
    trailing = (*args, *kwargs.values())
    if not sys.gettrace():
        mp.spawn(func, args=(world_size, *trailing), nprocs=world_size, join=True)
        return
    with DisableDistributed():
        func(0, world_size, *trailing)


def int_divide(a: int, b: int) -> int:
    """Return ``a // b``, raising if ``b`` does not divide ``a`` exactly.

    Used to split batches, layers and hidden dims evenly across ranks, where a
    remainder would otherwise silently drop data.

    Raises:
        ValueError: if ``a`` is not a multiple of ``b``.
        ZeroDivisionError: if ``b == 0``.
    """
    quotient, remainder = divmod(a, b)
    if remainder != 0:
        # An explicit raise (not ``assert``) so the check survives ``python -O``.
        raise ValueError(f"{a} is not divisible by {b}")
    return quotient

def summarize_tensor(tensor: torch.Tensor) -> str:
    """Return a compact ``"<d0>x<d1>...[<first value>...]"`` summary of ``tensor``.

    The first element (in logical row-major order) is rounded to 4 decimals.
    ``reshape`` is used instead of ``view`` so non-contiguous tensors
    (transposes, column slices) are accepted. An empty tensor renders as
    ``"<shape>[]"`` instead of raising ``IndexError``.
    """
    shape = "x".join(map(str, tensor.shape))
    flat = tensor.detach().reshape(-1)
    if flat.numel() == 0:
        return f"{shape}[]"
    return f"{shape}[{round(flat[0].item(), 4)}...]"


def get_init_params(num_inputs: int, num_outputs: int, rank: int) -> nn.Parameter:
    """Create a ``num_inputs x num_outputs`` weight matrix on ``rank``'s device.

    Entries are standard-normal samples divided by ``sqrt(num_outputs)``.
    The global torch RNG is reseeded to 0 on every call. Consequently, calls
    with the same shape on the same kind of device return identical values
    (every layer and every rank starts from the same weights), and the caller's
    RNG state is reset as a side effect.
    """
    device = get_device(rank)
    torch.random.manual_seed(0)  # For reproducibility
    weights = torch.randn(num_inputs, num_outputs, device=device)
    return nn.Parameter(weights / math.sqrt(num_outputs))


def render_duration(duration: float) -> str:
    """Format ``duration`` (seconds) with two decimals and a readable unit.

    Values below 1 ms are shown in microseconds, values below 1 s in
    milliseconds, and everything else (including NaN) in seconds.
    """
    if duration < 1:
        return f"{duration * 1e6:.2f}us" if duration < 1e-3 else f"{duration * 1e3:.2f}ms"
    return f"{duration:.2f}s"

def get_device(index: int = 0) -> torch.device:
    """Return ``cuda:<index>`` when CUDA is available, otherwise the CPU device."""
    return torch.device(f"cuda:{index}") if torch.cuda.is_available() else torch.device("cpu")
# Communication table for the hard-coded 8-rank layout
# rank = dp_rank * 4 + pp_rank * 2 + tp_rank (2 DP replicas x 2 PP stages x 2 TP shards).
comm_map = {
    "pp": {
        # First-stage ranks forward activations to the rank two ahead (next stage, same TP shard)...
        "send": {r: r + 2 for r in (0, 1, 4, 5)},
        # ...and second-stage ranks receive from the rank two behind.
        "recv": {r: r - 2 for r in (2, 3, 6, 7)},
    },
    # Consecutive rank pairs share a tensor-parallel group.
    "tp": {"gather": {r: f"tp_{r // 2}" for r in range(8)}},
    # Ranks four apart hold the same shard in the two data-parallel replicas.
    "dp": {"gather": {r: f"dp_{r % 4}" for r in range(8)}},
}

def setup(rank: int, world_size: int):
    """Initialise ``torch.distributed`` for this worker and build the 3D sub-groups.

    Uses NCCL when CUDA is available, otherwise gloo, with the rendezvous on
    ``localhost:15623``.

    For ``world_size == 8`` it also creates the groups of the hard-coded
    2 (DP) x 2 (PP) x 2 (TP) layout that ``comm_map`` refers to:

    * ``tp_i`` = ranks ``(2i, 2i + 1)``: tensor-parallel partners;
    * ``dp_i`` = ranks ``(i, i + 4)``: the same shard in the two DP replicas.

    Any other world size (e.g. the 2-rank ``data_parallelism`` demo) only
    needs the default group, so an empty dict is returned. Previously it
    failed an ``assert`` after the default group had already been created.

    Returns:
        Mapping from group name to the handle returned by ``dist.new_group``.
    """
    # Specify where master lives (rank 0), used to coordinate (actual data goes through NCCL)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "15623"
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)

    process_groups = {}
    if world_size != 8:
        # Only the 8-rank 3D layout has sub-groups.
        return process_groups

    # new_group must be entered by every rank, for every group, in the same order.
    half = world_size // 2
    for i in range(half):
        process_groups[f"tp_{i}"] = dist.new_group(ranks=[2 * i, 2 * i + 1])
    for i in range(half):
        process_groups[f"dp_{i}"] = dist.new_group(ranks=[i, i + half])
    return process_groups
def main():
    """Script entry point: run the 8-rank 3D-parallel (DP x TP x PP) demo."""
    dp_tp_pp()

def data_parallelism():
    """Run the 2-rank data-parallel MLP demo (4 layers, 1 step) on a fresh random batch."""
    spawn(data_parallelism_main, world_size=2, data=generate_sample_data(), num_layers=4, num_steps=1)

def generate_sample_data(batch_size: int = 4, num_dim: int = 32) -> torch.Tensor:
    """Return a ``batch_size x num_dim`` CPU tensor of standard-normal samples.

    The defaults reproduce the original hard-coded 4 x 32 batch. No seed is
    set: the values come from the current global torch RNG state.
    """
    return torch.randn(batch_size, num_dim)

def data_parallelism_main(rank: int, world_size: int, data: torch.Tensor, num_layers: int, num_steps: int):
    """Worker for the data-parallel GELU-MLP demo (one process per rank).

    Each rank trains on its own contiguous ``batch_size / world_size`` slice of
    ``data`` while holding a full copy of all ``num_layers`` square weights.
    After every backward pass the gradients are averaged over all ranks. Given
    identical initial weights, every rank therefore applies the same AdamW
    update.
    """
    setup(rank, world_size)
    # Get the slice of data for this rank (in practice, each rank should load only its own data)
    batch_size = data.size(0)  # @inspect batch_size
    num_dim = data.size(1)  # @inspect num_dim
    local_batch_size = int_divide(batch_size, world_size)  # @inspect local_batch_size
    start_index = rank * local_batch_size  # @inspect start_index
    end_index = start_index + local_batch_size  # @inspect end_index
    data = data[start_index:end_index].to(get_device(rank))
    # Create MLP parameters params[0], ..., params[num_layers - 1] (each rank has all parameters)
    params = [get_init_params(num_dim, num_dim, rank) for _ in range(num_layers)]
    optimizer = torch.optim.AdamW(params, lr=1e-3)  # Each rank has own optimizer state
    for step in range(num_steps):
        # loss.backward() accumulates into .grad, so clear the previous step's gradients first.
        optimizer.zero_grad(set_to_none=True)
        # Forward pass
        x = data
        for param in params:
            x = F.gelu(x @ param)
        loss = x.square().mean()  # Loss function is average squared magnitude
        # Backward pass
        loss.backward()
        # Sync gradients across workers (only difference between standard training and DDP).
        # ReduceOp.AVG is NCCL-only, so sum and divide to keep the gloo (CPU) fallback working.
        for param in params:
            dist.all_reduce(tensor=param.grad, op=dist.ReduceOp.SUM, async_op=False)
            param.grad /= world_size
        # Update parameters
        optimizer.step()
        print(f"[data_parallelism] Rank {rank}: step = {step}, loss = {loss.item()}, params = {[summarize_tensor(params[i]) for i in range(num_layers)]}", flush=True)
    cleanup()

def gelu_derivative(x, approximate: str = "none"):
    """Elementwise derivative of GELU, evaluated at the pre-activation ``x``.

    ``approximate`` mirrors ``torch.nn.functional.gelu``:

    * ``"none"`` (the default of ``F.gelu``, which is what ``F.gelu(x)`` in
      this file computes): exact GELU ``x * Phi(x)``, whose derivative is
      ``Phi(x) + x * phi(x)`` with ``Phi``/``phi`` the standard normal CDF/PDF.
    * ``"tanh"``: derivative of the tanh approximation
      ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``.

    Raises:
        ValueError: for any other ``approximate`` value.
    """
    if approximate == "none":
        cdf = 0.5 * (1.0 + torch.erf(x * (1.0 / math.sqrt(2.0))))
        pdf = torch.exp(-0.5 * x * x) * (1.0 / math.sqrt(2.0 * math.pi))
        return cdf + x * pdf
    if approximate == "tanh":
        sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
        c = 0.044715
        u = sqrt_2_over_pi * (x + c * x**3)
        tanh_u = torch.tanh(u)
        term1 = 0.5 * (1.0 + tanh_u)
        term2 = 0.5 * x * (1 - tanh_u**2) * sqrt_2_over_pi * (1 + 3 * c * x**2)
        return term1 + term2
    raise ValueError(f"approximate must be 'none' or 'tanh', got {approximate!r}")

def mean_square_derivative(x):
    """Gradient of ``x.square().mean()`` with respect to ``x``, i.e. ``2 * x / x.numel()``."""
    return x.mul(2.0 / x.numel())

def dp_tp_pp():
    """Run the 3D-parallel demo on 8 ranks with 4 layers and 2 micro-batches.

    The parallel degrees come from ``dp_tp_pp_main``'s defaults (tp=2, pp=2, dp=2).
    """
    sample = generate_sample_data()
    spawn(dp_tp_pp_main, world_size=8, data=sample, num_layers=4, num_micro_batches=2)

def _tp_all_gather(local: torch.Tensor, tp: int, group) -> torch.Tensor:
    """All-gather the ``tp`` column shards held by the ranks of ``group`` and concatenate them along dim 1."""
    shards = [torch.empty_like(local) for _ in range(tp)]
    dist.all_gather(tensor_list=shards, tensor=local.contiguous(), async_op=False, group=group)
    return torch.cat(shards, dim=1)


def dp_tp_pp_main(rank: int, world_size: int, data: torch.Tensor, num_layers: int, num_micro_batches: int, tp: int = 2, pp: int = 2, dp: int = 2):
    """Worker for one forward/backward pass of a GELU MLP under 3D parallelism.

    Rank layout (the one ``setup`` and ``comm_map`` hard-code for 8 ranks)::

        rank = dp_rank * (pp * tp) + pp_rank * tp + tp_rank

    * DP: each replica processes ``batch_size / dp`` rows of ``data``. Weight
      gradients are averaged across replicas at the end.
    * PP: each stage owns ``num_layers / pp`` consecutive layers. Activations
      flow forward and activation gradients flow backward via send/recv, one
      micro-batch at a time.
    * TP: every layer is column-parallel. Each rank holds a
      ``num_dim x (num_dim / tp)`` shard and multiplies the (replicated) input
      by it. The GELU outputs are all-gathered so the next layer again sees the
      full ``num_dim`` activation.

    Backpropagation is done by hand under ``torch.no_grad``. Per-layer weight
    gradients are accumulated over micro-batches into ``param.grad``.

    Raises:
        ValueError: if ``tp * pp * dp != world_size``.
    """
    if tp * pp * dp != world_size:
        raise ValueError(f"tp * pp * dp = {tp * pp * dp} does not match world_size = {world_size}")
    process_groups = setup(rank, world_size)
    device = get_device(rank)

    tp_rank = rank % tp
    dp_rank = rank // (tp * pp)

    # Every rank of a DP replica (all TP shards, all PP stages) uses the same
    # slice, so TP partners in the first stage feed identical inputs.
    batch_size, num_dim = data.size(0), data.size(1)
    local_batch_size = int_divide(batch_size, dp)
    start_index = dp_rank * local_batch_size
    data = data[start_index:start_index + local_batch_size].to(device)

    local_num_layers = int_divide(num_layers, pp)
    local_num_dim = int_divide(num_dim, tp)
    local_cols = slice(tp_rank * local_num_dim, (tp_rank + 1) * local_num_dim)
    local_params = [get_init_params(num_dim, local_num_dim, rank) for _ in range(local_num_layers)]
    micro_batch_size = int_divide(local_batch_size, num_micro_batches)

    tp_pg = process_groups[comm_map["tp"]["gather"][rank]]
    dp_pg = process_groups[comm_map["dp"]["gather"][rank]]
    src_rank = comm_map["pp"]["recv"].get(rank)  # previous stage; None on the first stage
    dst_rank = comm_map["pp"]["send"].get(rank)  # next stage; None on the last stage
    print(f"[3D parallelism] {rank=}, {src_rank=}, {dst_rank=}, {tp_rank=}", flush=True)

    total_loss = 0.0
    with torch.no_grad():
        for param in local_params:
            param.grad = torch.zeros_like(param)

        if src_rank is None:
            # First stage: every rank (not only rank 0) reads its replica's real data.
            micro_batches = data.chunk(chunks=num_micro_batches, dim=0)
        else:
            # Buffers that receive activations from the previous stage.
            micro_batches = [torch.empty(micro_batch_size, num_dim, device=device) for _ in range(num_micro_batches)]

        # Not handled: overlapping communication/computation to eliminate pipeline bubbles.
        for x in micro_batches:
            # ---- forward ----
            if src_rank is not None:
                dist.recv(tensor=x, src=src_rank)
            saved = []  # (layer input, local pre-activation) per layer, for this micro-batch only
            for param in local_params:
                z_local = x @ param  # [mb, num_dim] @ [num_dim, local_num_dim]
                saved.append((x, z_local))
                x = _tp_all_gather(F.gelu(z_local), tp, tp_pg)  # [mb, num_dim]

            # ---- gradient w.r.t. this stage's output ----
            if dst_rank is not None:
                print(f"[3D parallelism] Rank {rank}: sending {summarize_tensor(x)} to rank {dst_rank}", flush=True)
                dist.send(tensor=x, dst=dst_rank)
                grad_x = torch.empty(micro_batch_size, num_dim, device=device)
                dist.recv(tensor=grad_x, src=dst_rank)
            else:
                # The loss is the mean over the whole local batch, i.e. the average of
                # equally sized micro-batch means; hence the 1 / num_micro_batches.
                total_loss += x.square().mean().item() / num_micro_batches
                grad_x = mean_square_derivative(x) / num_micro_batches

            # ---- backward ----
            for param, (x_in, z_local) in zip(reversed(local_params), reversed(saved)):
                # Only this rank's output columns depend on its weight shard.
                grad_z = grad_x[:, local_cols] * gelu_derivative(z_local)  # [mb, local_num_dim]
                param.grad += x_in.T @ grad_z  # [num_dim, local_num_dim], same shape as param
                # Each shard yields a partial gradient for the replicated input;
                # summing over the TP group gives the full input gradient.
                grad_x = grad_z @ param.T  # [mb, num_dim]
                dist.all_reduce(tensor=grad_x, op=dist.ReduceOp.SUM, async_op=False, group=tp_pg)

            if src_rank is not None:
                print(f"[3D parallelism] Rank {rank}: sending {summarize_tensor(grad_x)} to rank {src_rank}", flush=True)
                dist.send(tensor=grad_x, dst=src_rank)

            print(f"[3D parallelism] Rank {rank}: finish micro batch", flush=True)
            dist.barrier()

        # Average weight gradients across DP replicas. ReduceOp.AVG is NCCL-only,
        # so sum and divide to keep the gloo (CPU) fallback working.
        for param in local_params:
            dist.all_reduce(tensor=param.grad, op=dist.ReduceOp.SUM, async_op=False, group=dp_pg)
            param.grad /= dp

    if dst_rank is None:
        print(f"[3D parallelism] Rank {rank}: loss = {total_loss}", flush=True)
    cleanup()

def cleanup():
    """Destroy the distributed process groups; called at the end of each worker."""
    dist.destroy_process_group()

# Script entry point: run the demo only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()