My implementation of pipeline parallelism from scratch. It performs manual back-propagation and communicates activation gradients across the layer boundary between adjacent ranks.
The following code is inspired by the excellent Stanford CS336: Language Modeling from Scratch.
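To make the gradient hand-off concrete, here is a minimal single-process sketch (mine, not part of the pipeline code below) of what each stage does: the downstream stage back-propagates to its own input and sends that gradient upstream, where the previous stage resumes the chain rule from it. The toy two-stage model and tensor shapes are illustrative only; the real pipeline below computes these gradients by hand instead of with autograd.
import torch

w1 = torch.randn(8, 8, requires_grad=True)  # "stage 0" weights (would live on rank 0)
w2 = torch.randn(8, 8, requires_grad=True)  # "stage 1" weights (would live on rank 1)
x = torch.randn(4, 8)

# Stage 0 forward, then hand the activation to stage 1 (in the real code: dist.send / dist.recv).
h = torch.tanh(x @ w1)
h_received = h.detach().requires_grad_(True)  # stands in for the tensor received on rank 1

# Stage 1 forward and loss, then backward to its own input.
loss = torch.tanh(h_received @ w2).square().mean()
loss.backward()  # fills w2.grad and h_received.grad

# Stage 0 backward resumes the chain rule from the gradient it "received".
h.backward(h_received.grad)  # fills w1.grad
print(w1.grad.shape, w2.grad.shape)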
from inspect import isfunction
from typing import Callable, List
import math
import os
import sys
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
** Helper functions for parallelism
class DisableDistributed:
    """Context manager that temporarily disables distributed functions (replaces them with no-ops)."""
    def __enter__(self):
        self.old_functions = {}
        for name in dir(dist):
            value = getattr(dist, name, None)
            if isfunction(value):
                self.old_functions[name] = value
                setattr(dist, name, lambda *args, **kwargs: None)

    def __exit__(self, exc_type, exc_value, traceback):
        for name in self.old_functions:
            setattr(dist, name, self.old_functions[name])
def spawn(func: Callable, world_size: int, *args, **kwargs):
    # Note: assumes kwargs are passed in the same order that func expects its arguments
    if sys.gettrace():
        # If we're being traced (e.g. under a debugger), run the function directly,
        # since we can't trace through mp.spawn
        with DisableDistributed():
            args = (0, world_size) + args + tuple(kwargs.values())
            func(*args)
    else:
        args = (world_size,) + args + tuple(kwargs.values())
        mp.spawn(func, args=args, nprocs=world_size, join=True)

def int_divide(a: int, b: int) -> int:
    """Return a // b and raise an error if there is a remainder."""
    assert a % b == 0
    return a // b
def summarize_tensor(tensor: torch.Tensor) -> str:
    return "x".join(map(str, tensor.shape)) + "[" + str(round(tensor.view(-1)[0].item(), 4)) + "...]"

def get_init_params(num_inputs: int, num_outputs: int, rank: int) -> nn.Parameter:
    torch.random.manual_seed(0)  # For reproducibility
    return nn.Parameter(torch.randn(num_inputs, num_outputs, device=get_device(rank)) / math.sqrt(num_outputs))

def render_duration(duration: float) -> str:
    if duration < 1e-3:
        return f"{duration * 1e6:.2f}us"
    if duration < 1:
        return f"{duration * 1e3:.2f}ms"
    return f"{duration:.2f}s"

def get_device(index: int = 0) -> torch.device:
    """Try to use the GPU if possible, otherwise use the CPU."""
    if torch.cuda.is_available():
        return torch.device(f"cuda:{index}")
    else:
        return torch.device("cpu")
** Data preparation function
def generate_sample_data():
    batch_size = 2
    num_dim = 32
    data = torch.randn(batch_size, num_dim)
    return data
** Manual derivatives for the backward pass
def gelu_derivative(x):
    """Derivative of the tanh-approximate GELU: d/dx [0.5 * x * (1 + tanh(u))], with u = sqrt(2/pi) * (x + 0.044715 * x^3)."""
    sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0 / torch.pi))
    c = 0.044715
    u = sqrt_2_over_pi * (x + c * x**3)
    tanh_u = torch.tanh(u)
    term1 = 0.5 * (1.0 + tanh_u)
    term2 = 0.5 * x * (1 - tanh_u**2) * sqrt_2_over_pi * (1 + 3 * c * x**2)
    return term1 + term2

def mean_square_derivative(x):
    """Derivative of loss = x.square().mean() with respect to x."""
    n = x.numel()
    return (2.0 / n) * x
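As a quick sanity check (not part of the original pipeline code), the hand-written derivatives can be compared against autograd. The helper name _check_derivatives below is mine; it assumes the tanh-approximate GELU, which is what gelu_derivative implements.
def _check_derivatives():
    # Hypothetical helper: compare the manual derivatives above against torch.autograd.
    x = torch.randn(4, 8, requires_grad=True)
    y = F.gelu(x, approximate="tanh").sum()  # tanh approximation, matching gelu_derivative
    (autograd_gelu,) = torch.autograd.grad(y, x)
    assert torch.allclose(autograd_gelu, gelu_derivative(x.detach()), atol=1e-5)

    x2 = torch.randn(4, 8, requires_grad=True)
    (autograd_ms,) = torch.autograd.grad(x2.square().mean(), x2)
    assert torch.allclose(autograd_ms, mean_square_derivative(x2.detach()), atol=1e-6)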
def main():
    pipeline_parallelism()  # Cut the model up along the depth (layer) dimension

def pipeline_parallelism():
    data = generate_sample_data()
    spawn(pipeline_parallelism_main, world_size=2, data=data, num_layers=4, num_micro_batches=2)
def pipeline_parallelism_main(rank: int, world_size: int, data: torch.Tensor, num_layers: int, num_micro_batches: int):
    setup(rank, world_size)

    # Use all the data
    data = data.to(get_device(rank))
    batch_size = data.size(0)  # @inspect batch_size
    num_dim = data.size(1)  # @inspect num_dim

    # Split up the layers: each rank gets a contiguous subset of them
    local_num_layers = int_divide(num_layers, world_size)  # @inspect local_num_layers
    local_params = [get_init_params(num_dim, num_dim, rank) for _ in range(local_num_layers)]

    # Forward pass
    # Break the batch into micro-batches to shrink the pipeline bubble
    micro_batch_size = int_divide(batch_size, num_micro_batches)  # @inspect micro_batch_size
    if rank == 0:
        # The first rank reads the actual data
        micro_batches = data.chunk(chunks=num_micro_batches, dim=0)
    else:
        # Later ranks allocate buffers for the activations they will receive
        micro_batches = [torch.empty(micro_batch_size, num_dim, device=get_device(rank)) for _ in range(num_micro_batches)]
    # For every micro-batch, checkpoint each local layer's (input, pre-activation) pair;
    # these are exactly the tensors the manual backward pass below needs.
    # Autograd stays disabled: all gradients here are computed by hand.
    activations_ckpt: List[List[tuple]] = []
    outputs = []
    with torch.no_grad():
        for x in micro_batches:
            # Get activations from the previous rank
            if rank - 1 >= 0:
                dist.recv(tensor=x, src=rank - 1)
            # Compute the layers assigned to this rank
            layer_ckpt = []
            for param in local_params:
                z = x @ param  # pre-activation [bs, dim]
                layer_ckpt.append((x, z))
                x = F.gelu(z, approximate="tanh")  # tanh approximation, to match gelu_derivative
            activations_ckpt.append(layer_ckpt)
            outputs.append(x)
            # Send to the next rank
            if rank + 1 < world_size:
                print(f"[pipeline_parallelism] Rank {rank}: sending {summarize_tensor(x)} to rank {rank + 1}", flush=True)
                dist.send(tensor=x, dst=rank + 1)

        # Backward pass, also micro-batch by micro-batch; gradients are summed over micro-batches
        for param in local_params:
            param.grad = torch.zeros_like(param)
        for x, layer_ckpt in zip(outputs, activations_ckpt):
            if rank == world_size - 1:
                loss = x.square().mean()  # Loss function is the average squared magnitude
                full_grad_x = mean_square_derivative(x)  # dL/dx  [bs, dim]
            else:
                # Receive dL/d(output) for this micro-batch from the next rank
                full_grad_x = torch.empty(micro_batch_size, num_dim, device=get_device(rank))  # [bs, dim]
                dist.recv(tensor=full_grad_x, src=rank + 1)
            # Walk back through the local layers: input -> z = input @ param -> gelu(z)
            for param, (layer_input, z) in list(zip(local_params, layer_ckpt))[::-1]:
                grad_z = full_grad_x * gelu_derivative(z)  # dL/dz        [bs, dim]
                param.grad += layer_input.T @ grad_z       # dL/dparam    [dim, dim]
                full_grad_x = grad_z @ param.T             # dL/d(input)  [bs, dim]
            # Send the gradient w.r.t. this stage's input back to the previous rank
            if rank > 0:
                dist.send(tensor=full_grad_x, dst=rank - 1)

    cleanup()
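One way to sanity-check the manual pipeline backward pass (not in the original code) is to compare against a single-process reference that runs all num_layers layers with autograd; the sketch below introduces a hypothetical helper, reference_gradients, for that purpose. Because the pipeline sums per-micro-batch mean losses, its accumulated gradients are num_micro_batches times the gradients of the full-batch mean loss computed here.
def reference_gradients(data: torch.Tensor, num_layers: int) -> List[torch.Tensor]:
    # Hypothetical single-process reference: same layers and loss, gradients via autograd.
    # get_init_params reseeds with 0 on every call, so these weights match the pipeline's.
    num_dim = data.size(1)
    params = [get_init_params(num_dim, num_dim, rank=0) for _ in range(num_layers)]
    x = data.to(get_device(0))
    for param in params:
        x = F.gelu(x @ param, approximate="tanh")
    x.square().mean().backward()
    return [param.grad for param in params]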
############################################################
def setup(rank: int, world_size: int):
    # Specify where the master (rank 0) lives, used to coordinate the group
    # (the actual tensor data goes through NCCL/Gloo)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "15623"
    if torch.cuda.is_available():
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
    else:
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()
if __name__ == "__main__":
    main()