My implementation of pipeline parallelism from scratch. This requires manual back-propagation and communicating gradients between ranks, since each rank only holds a slice of the layers.

The following code is inspired by the excellent Stanford CS336: Language Modeling from Scratch.
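
For reference, each local layer computes $Z = XW$ followed by $A = \mathrm{gelu}(Z)$ (tanh approximation), and the manual backward pass applies the chain rule per layer:

$$\frac{\partial L}{\partial Z} = \frac{\partial L}{\partial A} \odot \mathrm{gelu}'(Z), \qquad \frac{\partial L}{\partial W} = X^{\top} \frac{\partial L}{\partial Z}, \qquad \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Z}\, W^{\top}$$

Here $\partial L / \partial A$ for a rank's last local layer is received from the next rank, and $\partial L / \partial X$ for its first local layer is sent to the previous rank.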


from inspect import isfunction
from typing import Callable

import math
import os
import sys

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F

** Helper functions for parallelism

class DisableDistributed:
    """Context manager that temporarily disables distributed functions (replaces with no-ops)"""
    def __enter__(self):
        self.old_functions = {}
        for name in dir(dist):
            value = getattr(dist, name, None)
            if isfunction(value):
                self.old_functions[name] = value
                setattr(dist, name, lambda *args, **kwargs: None)

    def __exit__(self, exc_type, exc_value, traceback):
        for name in self.old_functions:
            setattr(dist, name, self.old_functions[name])


def spawn(func: Callable, world_size: int, *args, **kwargs):
    # Note: assume kwargs are in the same order as what main needs
    if sys.gettrace():
        # If we're being traced, run the function directly, since we can't trace through mp.spawn
        with DisableDistributed():
            args = (0, world_size,) + args + tuple(kwargs.values())
            func(*args)
    else:
        args = (world_size,) + args + tuple(kwargs.values())
        mp.spawn(func, args=args, nprocs=world_size, join=True)
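
A minimal usage sketch for spawn (the worker name here is illustrative, not part of the pipeline code): each spawned process is invoked as func(rank, world_size, *extra_args), so workers follow the same signature as pipeline_parallelism_main below.

def hello_worker(rank: int, world_size: int, message: str):
    # Hypothetical example worker: set up the process group, print, tear down
    setup(rank, world_size)
    print(f"[hello_worker] rank {rank}/{world_size}: {message}", flush=True)
    cleanup()

# spawn(hello_worker, world_size=2, message="hello from each rank")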


def int_divide(a: int, b: int):
    """Return a / b and throw an error if there's a remainder."""
    assert a % b == 0
    return a // b

def summarize_tensor(tensor: torch.Tensor) -> str:
    return "x".join(map(str, tensor.shape)) + "[" + str(round(tensor.view(-1)[0].item(), 4)) + "...]"


def get_init_params(num_inputs: int, num_outputs: int, rank: int) -> nn.Parameter:
    torch.random.manual_seed(0)  # Fixed seed for reproducibility (note: every call returns identical weights)
    return nn.Parameter(torch.randn(num_inputs, num_outputs, device=get_device(rank)) / math.sqrt(num_outputs))


def render_duration(duration: float) -> str:
    if duration < 1e-3:
        return f"{duration * 1e6:.2f}us"
    if duration < 1:
        return f"{duration * 1e3:.2f}ms"
    return f"{duration:.2f}s"

def get_device(index: int = 0) -> torch.device:
    """Try to use the GPU if possible, otherwise, use CPU."""
    if torch.cuda.is_available():
        return torch.device(f"cuda:{index}")
    else:
        return torch.device("cpu")

** Data preparation function

def generate_sample_data():
    batch_size = 2
    num_dim = 32
    data = torch.randn(batch_size, num_dim)
    return data

** Derivatives

def gelu_derivative(x):
    """Derivative of the tanh approximation of GELU, i.e. F.gelu(x, approximate="tanh")."""
    # Constants of the tanh approximation
    sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0 / torch.pi))
    c = 0.044715

    u = sqrt_2_over_pi * (x + c * x**3)
    tanh_u = torch.tanh(u)

    term1 = 0.5 * (1.0 + tanh_u)
    term2 = 0.5 * x * (1 - tanh_u**2) * sqrt_2_over_pi * (1 + 3 * c * x**2)
    return term1 + term2

def mean_square_derivative(x):
    """Derivative of x.square().mean() with respect to x."""
    n = x.numel()
    return (2.0 / n) * x
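
A quick autograd comparison (illustrative, not called by the pipeline) can confirm both hand-written derivatives. Note that gelu_derivative differentiates the tanh approximation, so it is checked against F.gelu(..., approximate="tanh").

def check_derivatives():
    # Compare the manual derivatives against autograd on small random inputs
    x = torch.randn(4, 8, requires_grad=True)
    F.gelu(x, approximate="tanh").sum().backward()
    assert torch.allclose(x.grad, gelu_derivative(x.detach()), atol=1e-5)

    z = torch.randn(4, 8, requires_grad=True)
    z.square().mean().backward()
    assert torch.allclose(z.grad, mean_square_derivative(z.detach()), atol=1e-6)
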
def main():
    pipeline_parallelism()     # Cut up along the depth dimension

def pipeline_parallelism():
    data = generate_sample_data()
    spawn(pipeline_parallelism_main, world_size=2, data=data, num_layers=4, num_micro_batches=2)

def pipeline_parallelism_main(rank: int, world_size: int, data: torch.Tensor, num_layers: int, num_micro_batches: int):
    setup(rank, world_size)
    # Use all the data
    data = data.to(get_device(rank))
    batch_size = data.size(0)  # @inspect batch_size
    num_dim = data.size(1)  # @inspect num_dim
    # Split up layers
    local_num_layers = int_divide(num_layers, world_size)  # @inspect local_num_layers
    # Each rank gets a subset of layers
    local_params = [get_init_params(num_dim, num_dim, rank) for i in range(local_num_layers)]
    # Forward pass
    # Break up into micro batches to minimize the bubble
    micro_batch_size = int_divide(batch_size, num_micro_batches)  # @inspect micro_batch_size
    if rank == 0:
        # The data
        micro_batches = data.chunk(chunks=num_micro_batches, dim=0)
    else:
        # Allocate memory for activations
        micro_batches = [torch.empty(micro_batch_size, num_dim, device=get_device(rank)) for _ in range(num_micro_batches)]

    for x in micro_batches:
        # Get activations from the previous rank
        if rank - 1 >= 0:
            dist.recv(tensor=x, src=rank - 1)
        # Forward through the layers assigned to this rank, checkpointing what the
        # manual backward pass needs: each layer's input and its pre-activation
        activations_ckpt = []
        for param in local_params:
            layer_input = x
            pre_activation = layer_input @ param
            x = F.gelu(pre_activation, approximate="tanh")  # tanh approximation matches gelu_derivative
            activations_ckpt.append((layer_input, pre_activation))
        # Send to the next rank
        if rank + 1 < world_size:
            print(f"[pipeline_parallelism] Rank {rank}: sending {summarize_tensor(x)} to rank {rank + 1}", flush=True)
            dist.send(tensor=x, dst=rank + 1)

        # Backward pass (also per micro-batch)
        if rank == world_size - 1:
            loss = x.square().mean()  # Loss function is average squared magnitude
            print(f"[pipeline_parallelism] Rank {rank}: micro-batch loss = {loss.item():.6f}", flush=True)
            full_grad_x = mean_square_derivative(x)  # dL/d(output): [micro_batch_size, num_dim]
        else:
            # Receive dL/d(output of this rank) from the next rank
            full_grad_x = torch.empty(micro_batch_size, num_dim, device=get_device(rank))  # [micro_batch_size, num_dim]
            dist.recv(tensor=full_grad_x, src=rank + 1)

        # Walk the local layers in reverse, applying the chain rule by hand
        for param, (layer_input, pre_activation) in reversed(list(zip(local_params, activations_ckpt))):
            grad_wx = full_grad_x * gelu_derivative(pre_activation)  # dL/d(pre_activation): [bs, dim]
            grad_param = layer_input.T @ grad_wx  # dL/dW: [dim, bs] x [bs, dim] --> [dim, dim]
            full_grad_x = grad_wx @ param.T  # dL/d(layer_input): [bs, dim] x [dim, dim] --> [bs, dim]
            # Accumulate (sum) gradients across micro-batches
            param.grad = grad_param if param.grad is None else param.grad + grad_param

        # Send dL/d(input of this rank) back to the previous rank
        if rank > 0:
            dist.send(tensor=full_grad_x, dst=rank - 1)

    cleanup()
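
A quick way to sanity-check the pipeline is a single-process reference that rebuilds the same layer stack with autograd. This is an illustrative sketch, not part of the pipeline: because get_init_params always seeds with 0, the reference sees the same weights, and with equal-sized micro-batches the mean of the per-micro-batch losses printed above should equal the full-batch loss here. (The manually accumulated param.grad is a sum over micro-batches, so it should be num_micro_batches times the reference gradient.)

def reference_check(data: torch.Tensor, num_layers: int = 4):
    """Illustrative single-process check: run the same layer stack with autograd."""
    x = data.to(get_device(0))
    params = [get_init_params(x.size(1), x.size(1), rank=0) for _ in range(num_layers)]
    for param in params:
        x = F.gelu(x @ param, approximate="tanh")
    loss = x.square().mean()  # Should match the mean of the per-micro-batch losses
    loss.backward()
    print(f"[reference_check] loss = {loss.item():.6f}")
    return loss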

############################################################
def setup(rank: int, world_size: int):
    # Specify where master lives (rank 0), used to coordinate (actual data goes through NCCL)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "15623"
    if torch.cuda.is_available():
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
    else:
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    torch.distributed.destroy_process_group()

if __name__ == "__main__":
    main()