My implementation of tensor parallelism from scratch. This requires manual back-propagation through the tensor-parallel (TP) layers.

The following code is inspired by the excellent Stanford CS336: Language Modeling from Scratch.
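
As a quick illustration of the sharding scheme (my own sketch, not from the course code): each layer's weight matrix is split column-wise across ranks, every rank multiplies the full input by its own shard, and concatenating the per-shard outputs reproduces the unsharded matmul exactly. The standalone snippet below checks this on a single process; it is separate from the script that follows.

import torch

torch.manual_seed(0)
x = torch.randn(2, 8)                              # full input (replicated on every rank)
W = torch.randn(8, 8)                              # full weight (never materialized in real TP)
shards = torch.chunk(W, chunks=2, dim=1)           # each "rank" holds one 8x4 column shard
out = torch.cat([x @ s for s in shards], dim=1)    # what the all-gather + cat reconstructs
print(torch.allclose(out, x @ W))                  # True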


import math
import os
import sys
from inspect import isfunction
from typing import Callable

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
# distributed helper
class DisableDistributed:
    """Context manager that temporarily disables distributed functions (replaces with no-ops)"""
    def __enter__(self):
        self.old_functions = {}
        for name in dir(dist):
            value = getattr(dist, name, None)
            if isfunction(value):
                self.old_functions[name] = value
                setattr(dist, name, lambda *args, **kwargs: None)

    def __exit__(self, exc_type, exc_value, traceback):
        for name in self.old_functions:
            setattr(dist, name, self.old_functions[name])


def spawn(func: Callable, world_size: int, *args, **kwargs):
    # Note: assume kwargs are in the same order as what main needs
    if sys.gettrace():
        # If we're being traced, run the function directly, since we can't trace through mp.spawn
        with DisableDistributed():
            args = (0, world_size,) + args + tuple(kwargs.values())
            func(*args)
    else:
        args = (world_size,) + args + tuple(kwargs.values())
        mp.spawn(func, args=args, nprocs=world_size, join=True)


def int_divide(a: int, b: int):
    """Return a / b and throw an error if there's a remainder."""
    assert a % b == 0
    return a // b

def summarize_tensor(tensor: torch.Tensor) -> str:
    return "x".join(map(str, tensor.shape)) + "[" + str(round(tensor.view(-1)[0].item(), 4)) + "...]"


def get_init_params(num_inputs: int, num_outputs: int, rank: int) -> nn.Parameter:
    torch.random.manual_seed(0)  # For reproducibility
    return nn.Parameter(torch.randn(num_inputs, num_outputs, device=get_device(rank)) / math.sqrt(num_outputs))


def render_duration(duration: float) -> str:
    if duration < 1e-3:
        return f"{duration * 1e6:.2f}us"
    if duration < 1:
        return f"{duration * 1e3:.2f}ms"
    return f"{duration:.2f}s"

def get_device(index: int = 0) -> torch.device:
    """Try to use the GPU if possible, otherwise, use CPU."""
    if torch.cuda.is_available():
        return torch.device(f"cuda:{index}")
    else:
        return torch.device("cpu")
# generate dummy data
def generate_sample_data():
    batch_size = 2
    num_dim = 32
    data = torch.randn(batch_size, num_dim)
    return data
def tensor_parallelism():
    data = generate_sample_data()
    spawn(tensor_parallelism_main, world_size=2, data=data, num_layers=4)


def gelu_derivative(x):
    """Derivative of the tanh-approximation GELU, i.e. of F.gelu(x, approximate="tanh")."""
    # constants
    sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0 / torch.pi))
    c = 0.044715

    u = sqrt_2_over_pi * (x + c * x**3)
    tanh_u = torch.tanh(u)

    term1 = 0.5 * (1.0 + tanh_u)
    term2 = 0.5 * x * (1 - tanh_u**2) * sqrt_2_over_pi * (1 + 3 * c * x**2)
    return term1 + term2

def mean_square_derivative(x):
    """Derivative of x.square().mean() with respect to x."""
    n = x.numel()
    return (2.0 / n) * x
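
# Optional sanity check (my addition, not in the original): both hand-written derivatives
# should match autograd. Note that F.gelu defaults to the exact (erf-based) GELU, while
# gelu_derivative differentiates the tanh approximation, hence approximate="tanh" below.
def check_derivatives():
    t = torch.randn(4, 8, requires_grad=True)
    F.gelu(t, approximate="tanh").sum().backward()
    assert torch.allclose(t.grad, gelu_derivative(t.detach()), atol=1e-6)

    u = torch.randn(4, 8, requires_grad=True)
    u.square().mean().backward()
    assert torch.allclose(u.grad, mean_square_derivative(u.detach()), atol=1e-6)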

def tensor_parallelism_main(rank: int, world_size: int, data: torch.Tensor, num_layers: int):
    setup(rank, world_size)
    data = data.to(get_device(rank))
    batch_size = data.size(0)  # @inspect batch_size
    num_dim = data.size(1)  # @inspect num_dim
    local_num_dim = int_divide(num_dim, world_size)  # Shard `num_dim`  @inspect local_num_dim
    # Create model (each rank gets 1/world_size of the parameters)
    params = [get_init_params(num_dim, local_num_dim, rank) for i in range(num_layers)]
    optimizer = torch.optim.AdamW(params, lr=1e-3)  # Each rank has own optimizer state
    # Forward pass (a column-parallel linear layer followed by GELU, per layer)
    x = data
    inputs_ckpt = []    # layer inputs, needed for the weight gradients
    preacts_ckpt = []   # local pre-activations, needed for the GELU derivative
    for i in range(num_layers):
        inputs_ckpt.append(x)
        # Compute local pre-activations (batch_size x local_num_dim)
        z_local = x @ params[i]  # Note: this is only on a slice of the parameters
        preacts_ckpt.append(z_local)
        # Use the tanh approximation so the forward matches gelu_derivative in the backward pass
        a_local = F.gelu(z_local, approximate="tanh")
        # Allocate memory for activations (world_size x batch_size x local_num_dim)
        activations = [torch.empty(batch_size, local_num_dim, device=get_device(rank)) for _ in range(world_size)]
        # Send activations via all-gather
        dist.all_gather(tensor_list=activations, tensor=a_local, async_op=False)
        # Concatenate them to get batch_size x num_dim
        x = torch.cat(activations, dim=1)
    print(f"[tensor_parallelism] Rank {rank}: forward pass produced activations {summarize_tensor(x)}", flush=True)
    # Backward pass: done manually (no autograd), mirroring the forward pass in reverse
    loss = x.square().mean()  # Loss function is average squared magnitude
    print(f"[tensor_parallelism] Rank {rank}: loss = {loss.item():.6f}", flush=True)

    with torch.no_grad():
        grad_x = mean_square_derivative(x)  # dL/d(last layer output): [bs, dim_x]
        for i in reversed(range(num_layers)):
            local_cols = slice(rank * local_num_dim, (rank + 1) * local_num_dim)
            # Gradient w.r.t. this rank's pre-activation slice: [bs, local_num_dim]
            grad_z_local = grad_x[:, local_cols] * gelu_derivative(preacts_ckpt[i])
            # Gradient w.r.t. this rank's parameter shard: [dim_x, bs] x [bs, local_num_dim] -> [dim_x, local_num_dim]
            params[i].grad = inputs_ckpt[i].T @ grad_z_local
            # Partial gradient w.r.t. the layer input: [bs, local_num_dim] x [local_num_dim, dim_x] -> [bs, dim_x]
            grad_x = grad_z_local @ params[i].T
            # Each rank only sees its own output columns, so sum the partial input gradients across ranks
            dist.all_reduce(grad_x, op=dist.ReduceOp.SUM, async_op=False)

    optimizer.step()
    cleanup()
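
# Reference check (my addition, not part of the original code): the per-rank manual gradients
# above should agree with autograd on an unsharded single-process version of one layer.
# Columns `cols` of the full weight play the role of one rank's shard.
def check_manual_backward_one_layer(world_size: int = 2):
    torch.manual_seed(0)
    num_dim = 8
    local_num_dim = int_divide(num_dim, world_size)
    x = torch.randn(2, num_dim, requires_grad=True)
    W = torch.randn(num_dim, num_dim, requires_grad=True)  # full (unsharded) weight
    out = F.gelu(x @ W, approximate="tanh")
    out.square().mean().backward()  # autograd reference: fills x.grad and W.grad

    with torch.no_grad():
        grad_out = mean_square_derivative(F.gelu(x @ W, approximate="tanh"))
        grad_x = torch.zeros_like(x)
        for rank in range(world_size):  # replay what each rank would compute
            cols = slice(rank * local_num_dim, (rank + 1) * local_num_dim)
            z_local = x @ W[:, cols]
            grad_z_local = grad_out[:, cols] * gelu_derivative(z_local)
            assert torch.allclose(x.T @ grad_z_local, W.grad[:, cols], atol=1e-5)
            grad_x += grad_z_local @ W[:, cols].T  # partial sums == the all-reduce above
        assert torch.allclose(grad_x, x.grad, atol=1e-5)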

############################################################
def setup(rank: int, world_size: int):
    # Specify where the master (rank 0) lives; used for coordination (actual data goes through NCCL, or Gloo on CPU)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "15623"
    if torch.cuda.is_available():
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
    else:
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    torch.distributed.destroy_process_group()

if __name__ == "__main__":
    tensor_parallelism()