My from-scratch implementation (using only PyTorch primitives) of 3D parallelism, combining data parallelism, tensor parallelism, and pipeline parallelism. This requires manual back-propagation and communicating gradients across layers and ranks.
The following code snippets are inspired by the amazing Stanford CS336: Language Modeling from Scratch.
The challenging part of combining the different parallelism strategies is that tensor, pipeline, and data parallelism each need to live in their own process groups. For 3D parallelism with a world size of 8, we have 2 dp (data-parallel) replicas -> each replica holds 2 mp (model/tensor-parallel) shards -> each shard is split across 2 pp (pipeline-parallel) stages, as sketched below.
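As a concrete illustration (not part of the training code), here is a minimal sketch of how the 8 ranks map onto (dp, pp, tp) coordinates under that 2 x 2 x 2 layout; the helper name rank_to_coords is my own and simply mirrors the hard-coded comm_map further down.

def rank_to_coords(rank: int) -> dict:
    """Illustrative only: map a global rank to its (dp, pp, tp) position in the 2x2x2 layout."""
    return {
        "dp": rank // 4,        # ranks 0-3 form data-parallel replica 0, ranks 4-7 replica 1
        "pp": (rank % 4) // 2,  # ranks 0, 1, 4, 5 are pipeline stage 0; ranks 2, 3, 6, 7 are stage 1
        "tp": rank % 2,         # adjacent ranks form a tensor-parallel pair
    }

# e.g. rank_to_coords(6) == {"dp": 1, "pp": 1, "tp": 0}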
Most importantly, we need to create the right distributed process groups, which is what setup() below does.
from inspect import isfunction
from typing import Callable
import sys
import torch.multiprocessing as mp
import torch.nn as nn
import math
import torch
import time
import os
import torch.nn.functional as F
import torch.distributed as dist
class DisableDistributed:
"""Context manager that temporarily disables distributed functions (replaces with no-ops)"""
def __enter__(self):
self.old_functions = {}
for name in dir(dist):
value = getattr(dist, name, None)
if isfunction(value):
self.old_functions[name] = value
setattr(dist, name, lambda *args, **kwargs: None)
def __exit__(self, exc_type, exc_value, traceback):
for name in self.old_functions:
setattr(dist, name, self.old_functions[name])
def spawn(func: Callable, world_size: int, *args, **kwargs):
# Note: assume kwargs are in the same order as what main needs
if sys.gettrace():
# If we're being traced, run the function directly, since we can't trace through mp.spawn
with DisableDistributed():
args = (0, world_size,) + args + tuple(kwargs.values())
func(*args)
else:
args = (world_size,) + args + tuple(kwargs.values())
mp.spawn(func, args=args, nprocs=world_size, join=True)
def int_divide(a: int, b: int):
"""Return a / b and throw an error if there's a remainder."""
assert a % b == 0
return a // b
def summarize_tensor(tensor: torch.Tensor) -> str:
return "x".join(map(str, tensor.shape)) + "[" + str(round(tensor.view(-1)[0].item(), 4)) + "...]"
def get_init_params(num_inputs: int, num_outputs: int, rank: int) -> nn.Parameter:
torch.random.manual_seed(0) # For reproducibility
return nn.Parameter(torch.randn(num_inputs, num_outputs, device=get_device(rank)) / math.sqrt(num_outputs))
def render_duration(duration: float) -> str:
if duration < 1e-3:
return f"{duration * 1e6:.2f}us"
if duration < 1:
return f"{duration * 1e3:.2f}ms"
return f"{duration:.2f}s"
def get_device(index: int = 0) -> torch.device:
"""Try to use the GPU if possible, otherwise, use CPU."""
if torch.cuda.is_available():
return torch.device(f"cuda:{index}")
else:
return torch.device("cpu")
# Hard-coded communication topology for world_size = 8 (tp = 2, pp = 2, dp = 2)
comm_map = {
    # Pipeline parallelism: stage-0 ranks (0, 1, 4, 5) send activations to stage-1 ranks (2, 3, 6, 7)
    "pp": {
        "send": {0: 2, 1: 3, 4: 6, 5: 7},
        "recv": {2: 0, 3: 1, 6: 4, 7: 5},
    },
    # Tensor parallelism: adjacent rank pairs all-gather their activation shards
    "tp": {
        "gather": {0: "tp_0", 1: "tp_0", 2: "tp_1", 3: "tp_1", 4: "tp_2", 5: "tp_2", 6: "tp_3", 7: "tp_3"},
    },
    # Data parallelism: ranks i and i + 4 hold replicas of the same shard and average gradients
    "dp": {
        "gather": {0: "dp_0", 1: "dp_1", 2: "dp_2", 3: "dp_3", 4: "dp_0", 5: "dp_1", 6: "dp_2", 7: "dp_3"},
    },
}
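# Example (reading the maps above): rank 5 sits in pipeline stage 0, so it sends its
# activations forward to rank 7, all-gathers activation shards with rank 4 (group "tp_2"),
# and averages weight gradients with rank 1 (group "dp_1").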
def setup(rank: int, world_size: int):
# Specify where master lives (rank 0), used to coordinate (actual data goes through NCCL)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "15623"
if torch.cuda.is_available():
dist.init_process_group("nccl", rank=rank, world_size=world_size)
else:
dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # The 3D-parallel demo hard-codes tp = 2, pp = 2, dp = 2, i.e. world_size = 8;
    # other world sizes (e.g. the pure data-parallel demo) just use the default group.
    process_groups = {}
    if world_size == 8:
        # Tensor-parallel groups: adjacent rank pairs share the column-sharded weights
        for i in range(8 // 2):
            process_groups[f"tp_{i}"] = dist.new_group(ranks=[i * 2, i * 2 + 1])
        # Data-parallel groups: ranks i and i + 4 hold replicas of the same model shard
        for i in range(8 // 2):
            process_groups[f"dp_{i}"] = dist.new_group(ranks=[i, i + 4])
    return process_groups
def main():
dp_tp_pp()
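# Note: data_parallelism() below is a standalone pure data-parallel demo (world_size = 2);
# main() only runs the 3D-parallel demo via dp_tp_pp().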
def data_parallelism():
data = generate_sample_data()
spawn(data_parallelism_main, world_size=2, data=data, num_layers=4, num_steps=1)
def generate_sample_data():
batch_size = 4
num_dim = 32
data = torch.randn(batch_size, num_dim)
return data
def data_parallelism_main(rank: int, world_size: int, data: torch.Tensor, num_layers: int, num_steps: int):
setup(rank, world_size)
# Get the slice of data for this rank (in practice, each rank should load only its own data)
batch_size = data.size(0) # @inspect batch_size
num_dim = data.size(1) # @inspect num_dim
local_batch_size = int_divide(batch_size, world_size) # @inspect local_batch_size
start_index = rank * local_batch_size # @inspect start_index
end_index = start_index + local_batch_size # @inspect end_index
data = data[start_index:end_index].to(get_device(rank))
# Create MLP parameters params[0], ..., params[num_layers - 1] (each rank has all parameters)
params = [get_init_params(num_dim, num_dim, rank) for i in range(num_layers)]
optimizer = torch.optim.AdamW(params, lr=1e-3) # Each rank has own optimizer state
for step in range(num_steps):
# Forward pass
x = data
for param in params:
x = x @ param
x = F.gelu(x)
loss = x.square().mean() # Loss function is average squared magnitude
# Backward pass
loss.backward()
# Sync gradients across workers (only difference between standard training and DDP)
        for param in params:
            # ReduceOp.AVG is NCCL-only, so sum and divide to also work with the Gloo/CPU fallback
            dist.all_reduce(tensor=param.grad, op=dist.ReduceOp.SUM, async_op=False)
            param.grad /= world_size
# Update parameters
optimizer.step()
print(f"[data_parallelism] Rank {rank}: step = {step}, loss = {loss.item()}, params = {[summarize_tensor(params[i]) for i in range(num_layers)]}", flush=True)
cleanup()
def gelu_derivative(x):
    """Derivative of the tanh-approximated GeLU (matches F.gelu(..., approximate="tanh"))."""
    # constants
sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0 / torch.pi))
c = 0.044715
u = sqrt_2_over_pi * (x + c * x**3)
tanh_u = torch.tanh(u)
term1 = 0.5 * (1.0 + tanh_u)
term2 = 0.5 * x * (1 - tanh_u**2) * sqrt_2_over_pi * (1 + 3 * c * x**2)
return term1 + term2
def mean_square_derivative(x):
n = x.numel()
return (2.0 / n) * x
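# Since the backward pass below is written by hand, it helps to sanity-check the two
# analytic derivatives above against autograd. This helper is my own illustrative
# addition (nothing in the script calls it) and assumes the tanh-approximated GeLU,
# which is the variant the 3D-parallel forward pass uses.
def check_manual_derivatives():
    """Illustrative sanity check: compare the hand-written derivatives against autograd."""
    x = torch.randn(5, 3, requires_grad=True)
    F.gelu(x, approximate="tanh").sum().backward()
    assert torch.allclose(x.grad, gelu_derivative(x.detach()), atol=1e-5)
    y = torch.randn(5, 3, requires_grad=True)
    y.square().mean().backward()
    assert torch.allclose(y.grad, mean_square_derivative(y.detach()), atol=1e-6)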
def dp_tp_pp():
data = generate_sample_data()
spawn(dp_tp_pp_main, world_size=8, data=data, num_layers=4, num_micro_batches=2)
def dp_tp_pp_main(rank: int, world_size: int, data: torch.Tensor, num_layers: int, num_micro_batches: int, tp: int = 2, pp: int = 2, dp: int = 2):
    # 3D parallelism: tp = 2 ranks per pipeline stage, pp = 2 stages, dp = 2 replicas (world size 8)
process_groups = setup(rank, world_size)
    # Slice the data for this rank's data-parallel replica
    batch_size = data.size(0)  # @inspect batch_size
    num_dim = data.size(1)  # @inspect num_dim
    local_batch_size = int_divide(batch_size, dp)  # @inspect local_batch_size
    dp_index = rank // 4  # ranks 0-3 form data-parallel replica 0, ranks 4-7 replica 1
    start_index = dp_index * local_batch_size  # @inspect start_index
    end_index = start_index + local_batch_size  # @inspect end_index
    data = data[start_index:end_index].to(get_device(rank))
# Split up layers
local_num_layers = int_divide(num_layers, pp) # @inspect local_num_layers
local_num_dim = int_divide(num_dim, tp)
    # Each pipeline stage gets a subset of layers; each layer's weight is column-sharded across tp ranks,
    # i.e. column-parallel tensor parallelism: the local shard is [num_dim, local_num_dim] while the input stays full width
    local_params = [get_init_params(num_dim, local_num_dim, rank) for i in range(local_num_layers)]
# Forward pass
# Break up into micro batches to minimize the bubble
micro_batch_size = int_divide(local_batch_size, num_micro_batches) # @inspect micro_batch_size
    if rank not in comm_map["pp"]["recv"]:
        # First pipeline stage: consume the real data
        micro_batches = data.chunk(chunks=num_micro_batches, dim=0)
    else:
        # Later stages: allocate buffers to receive activations into
        micro_batches = [torch.empty(micro_batch_size, num_dim, device=get_device(rank)) for _ in range(num_micro_batches)]
tp_pg = process_groups[comm_map["tp"]["gather"][rank]]
dp_pg = process_groups[comm_map["dp"]["gather"][rank]]
src_rank = comm_map["pp"]["recv"][rank] if rank in comm_map["pp"]["recv"] else None
dst_rank = comm_map["pp"]["send"][rank] if rank in comm_map["pp"]["send"] else None
    tp_rank = rank % 2  # position of this rank within its tensor-parallel pair
    print(f"[3D parallelism] {rank=}, {src_rank=}, {dst_rank=}, {tp_rank=}", flush=True)
    for x in micro_batches:
        # Checkpoint (layer_input, pre_activation) per layer so this micro-batch's backward can reuse them
        activations_ckpt = []
        # Get activations from the previous pipeline stage
        if src_rank is not None:
            dist.recv(tensor=x, src=src_rank)
        # Compute the layers assigned to this pipeline stage (column-parallel matmul + GeLU)
        for param in local_params:
            layer_input = x  # [micro_batch_size, num_dim], needed for the weight gradient
            z = x @ param  # [micro_batch_size, local_num_dim], pre-activation
            local_out = F.gelu(z, approximate="tanh")  # tanh variant so it matches gelu_derivative
            # Allocate memory for the gathered activations (tp x micro_batch_size x local_num_dim)
            activations = [torch.empty(micro_batch_size, local_num_dim, device=get_device(rank)) for _ in range(tp)]
            # Exchange the local activation shards via all-gather within the tp group
            dist.all_gather(tensor_list=activations, tensor=local_out, async_op=False, group=tp_pg)
            # Concatenate them to get micro_batch_size x num_dim
            x = torch.cat(activations, dim=1)
            activations_ckpt.append((layer_input, z))
# Send to the next rank
if dst_rank is not None:
print(f"[3D parallelism] Rank {rank}: sending {summarize_tensor(x)} to rank {dst_rank}", flush=True)
dist.send(tensor=x, dst=dst_rank)
        # Not handled: overlapping communication/computation to eliminate pipeline bubbles
        # Backward pass: the last pipeline stage seeds the gradient from the loss, earlier stages receive it
        if dst_rank is None:
            loss = x.square().mean()  # Loss function is average squared magnitude
            full_grad_x = mean_square_derivative(x)  # dL/d(stage output), [micro_batch_size, num_dim]
        else:
            full_grad_x = torch.empty(micro_batch_size, num_dim, device=get_device(rank))  # [micro_batch_size, num_dim]
            dist.recv(tensor=full_grad_x, src=dst_rank)
        # Walk the local layers in reverse, reusing the checkpointed inputs and pre-activations
        for param, (layer_input, z) in list(zip(local_params, activations_ckpt))[::-1]:
            # Only this rank's slice of the gathered output came from its weight columns
            local_grad = full_grad_x[:, tp_rank * local_num_dim:(tp_rank + 1) * local_num_dim]  # [micro_batch_size, local_num_dim]
            grad_z = local_grad * gelu_derivative(z)  # [micro_batch_size, local_num_dim]
            grad_param = layer_input.T @ grad_z  # [num_dim, micro_batch_size] x [micro_batch_size, local_num_dim] -> [num_dim, local_num_dim]
            param.grad = grad_param.detach()
            # dp sync: average the weight gradient across data-parallel replicas
            # (SUM then divide, since ReduceOp.AVG is NCCL-only)
            dist.all_reduce(tensor=param.grad, op=dist.ReduceOp.SUM, async_op=False, group=dp_pg)
            param.grad /= dp
            # Gradient w.r.t. the layer input: each tp rank holds a partial sum, reduce across the tp group
            grad_x = grad_z @ param.T  # [micro_batch_size, local_num_dim] x [local_num_dim, num_dim] -> [micro_batch_size, num_dim]
            dist.all_reduce(tensor=grad_x, op=dist.ReduceOp.SUM, async_op=False, group=tp_pg)
            full_grad_x = grad_x
        # Send the gradient of the stage input back to the previous pipeline stage
        if src_rank is not None:
            print(f"[3D parallelism] Rank {rank}: sending {summarize_tensor(full_grad_x)} to rank {src_rank}", flush=True)
            dist.send(tensor=full_grad_x, dst=src_rank)
        print(f"[3D parallelism] Rank {rank}: finished micro-batch", flush=True)
dist.barrier()
cleanup()
def cleanup():
torch.distributed.destroy_process_group()
if __name__ == "__main__":
main()