My implementation of tensor parallelism from scratch. This requires manual backpropagation through the tensor-parallel layers.
The following code is inspired by the amazing Stanford CS336: Language Modeling from Scratch.
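For reference, here is the per-layer math the manual backward below implements (a sketch in informal notation, assuming each rank holds a column shard W_local of shape [d, d/world_size] and the tanh-approximate GELU):

z_local      = x_in @ W_local                         # per-rank pre-activation
a            = all_gather(gelu(z_local))              # full layer output, [batch, d]
grad_z_local = grad_a[:, this rank's columns] * gelu'(z_local)
grad_W_local = x_in.T @ grad_z_local
grad_x_in    = all_reduce_sum(grad_z_local @ W_local.T)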
from inspect import isfunction
import sys
import torch.multiprocessing as mp
import torch.nn as nn
import math
import torch
import os
from typing import Callable
import torch.nn.functional as F
import torch.distributed as dist
# distributed helper
class DisableDistributed:
"""Context manager that temporarily disables distributed functions (replaces with no-ops)"""
def __enter__(self):
self.old_functions = {}
for name in dir(dist):
value = getattr(dist, name, None)
if isfunction(value):
self.old_functions[name] = value
setattr(dist, name, lambda *args, **kwargs: None)
def __exit__(self, exc_type, exc_value, traceback):
for name in self.old_functions:
setattr(dist, name, self.old_functions[name])
def spawn(func: Callable, world_size: int, *args, **kwargs):
# Note: assume kwargs are in the same order as what main needs
if sys.gettrace():
# If we're being traced, run the function directly, since we can't trace through mp.spawn
with DisableDistributed():
args = (0, world_size,) + args + tuple(kwargs.values())
func(*args)
else:
args = (world_size,) + args + tuple(kwargs.values())
mp.spawn(func, args=args, nprocs=world_size, join=True)
def int_divide(a: int, b: int):
"""Return a / b and throw an error if there's a remainder."""
assert a % b == 0
return a // b
def summarize_tensor(tensor: torch.Tensor) -> str:
return "x".join(map(str, tensor.shape)) + "[" + str(round(tensor.view(-1)[0].item(), 4)) + "...]"
def get_init_params(num_inputs: int, num_outputs: int, rank: int) -> nn.Parameter:
torch.random.manual_seed(0) # For reproducibility
return nn.Parameter(torch.randn(num_inputs, num_outputs, device=get_device(rank)) / math.sqrt(num_outputs))
def render_duration(duration: float) -> str:
if duration < 1e-3:
return f"{duration * 1e6:.2f}us"
if duration < 1:
return f"{duration * 1e3:.2f}ms"
return f"{duration:.2f}s"
def get_device(index: int = 0) -> torch.device:
"""Try to use the GPU if possible, otherwise, use CPU."""
if torch.cuda.is_available():
return torch.device(f"cuda:{index}")
else:
return torch.device("cpu")
# generate dummy data
def generate_sample_data():
batch_size = 2
num_dim = 32
data = torch.randn(batch_size, num_dim)
return data
def tensor_parallelism():
data = generate_sample_data()
spawn(tensor_parallelism_main, world_size=2, data=data, num_layers=4)
def gelu_derivative(x):
    """Derivative of the tanh-approximate GELU, i.e. of F.gelu(x, approximate="tanh")."""
    # constants
sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0 / torch.pi))
c = 0.044715
u = sqrt_2_over_pi * (x + c * x**3)
tanh_u = torch.tanh(u)
term1 = 0.5 * (1.0 + tanh_u)
term2 = 0.5 * x * (1 - tanh_u**2) * sqrt_2_over_pi * (1 + 3 * c * x**2)
return term1 + term2
def mean_square_derivative(x):
    """Derivative of x.square().mean() with respect to x."""
n = x.numel()
return (2.0 / n) * x
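# Optional sanity check (a small sketch, not part of the TP run): compare the
# hand-written derivatives above against PyTorch autograd. This assumes the
# forward pass uses the tanh-approximate GELU, which is what gelu_derivative
# differentiates.
def check_derivatives():
    x = torch.randn(8, requires_grad=True)
    F.gelu(x, approximate="tanh").sum().backward()
    assert torch.allclose(x.grad, gelu_derivative(x.detach()), atol=1e-5)
    y = torch.randn(4, 6, requires_grad=True)
    y.square().mean().backward()
    assert torch.allclose(y.grad, mean_square_derivative(y.detach()), atol=1e-6)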
def tensor_parallelism_main(rank: int, world_size: int, data: torch.Tensor, num_layers: int):
setup(rank, world_size)
data = data.to(get_device(rank))
batch_size = data.size(0) # @inspect batch_size
num_dim = data.size(1) # @inspect num_dim
local_num_dim = int_divide(num_dim, world_size) # Shard `num_dim` @inspect local_num_dim
# Create model (each rank gets 1/world_size of the parameters)
params = [get_init_params(num_dim, local_num_dim, rank) for i in range(num_layers)]
optimizer = torch.optim.AdamW(params, lr=1e-3) # Each rank has own optimizer state
    # Forward pass (checkpoint what the manual backward needs)
    x = data
    inputs_ckpt = []    # per-layer inputs (batch_size x num_dim), needed for dL/dW
    preacts_ckpt = []   # per-layer local pre-activations (batch_size x local_num_dim), needed for gelu'
    for i in range(num_layers):
        inputs_ckpt.append(x)
        # Compute local pre-activations (batch_size x local_num_dim)
        z = x @ params[i]  # Note: this uses only this rank's column slice of the parameters
        preacts_ckpt.append(z.detach())
        x = F.gelu(z, approximate="tanh")  # tanh approximation matches gelu_derivative above
        # Allocate memory for activations (world_size x batch_size x local_num_dim)
        activations = [torch.empty(batch_size, local_num_dim, device=get_device(rank)) for _ in range(world_size)]
        # Exchange activations via all-gather
        dist.all_gather(tensor_list=activations, tensor=x, async_op=False)
        # Concatenate them to get batch_size x num_dim
        x = torch.cat(activations, dim=1)
    print(f"[tensor_parallelism] Rank {rank}: forward pass produced activations {summarize_tensor(x)}", flush=True)
    # Backward pass (manual): walk the layers in reverse, keeping the full gradient
    # w.r.t. each layer's gathered output on every rank.
    loss = x.square().mean()  # Loss function is average squared magnitude
    print(f"[tensor_parallelism] Rank {rank}: loss = {loss.item():.6f}", flush=True)
    full_grad_x = mean_square_derivative(x)  # dL/d(output), [batch_size, num_dim]
    for i in reversed(range(num_layers)):
        # This rank owns output columns [rank*local_num_dim, (rank+1)*local_num_dim)
        grad_a_local = full_grad_x[:, rank*local_num_dim:(rank+1)*local_num_dim]  # [batch_size, local_num_dim]
        grad_z_local = grad_a_local * gelu_derivative(preacts_ckpt[i])            # [batch_size, local_num_dim]
        # dL/dW_local = x_in^T @ dL/dz_local
        params[i].grad = inputs_ckpt[i].T @ grad_z_local                          # [num_dim, local_num_dim]
        # dL/dx_in: each rank holds a partial sum; all-reduce to recover the full gradient
        grad_x_in = grad_z_local @ params[i].detach().T                           # [batch_size, num_dim]
        dist.all_reduce(grad_x_in, op=dist.ReduceOp.SUM, async_op=False)
        full_grad_x = grad_x_in
    optimizer.step()
cleanup()
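# Optional single-process check of the per-layer gradient math used above (a
# sketch, assuming the tanh-approximate GELU): run one dense layer through
# autograd and compare against the manual formulas grad_W = x_in.T @ grad_z and
# grad_x = grad_z @ W.T.
def check_layer_backward():
    torch.manual_seed(0)
    x_in = torch.randn(2, 32, requires_grad=True)
    w = torch.randn(32, 32, requires_grad=True)
    a = F.gelu(x_in @ w, approximate="tanh")
    a.square().mean().backward()
    grad_z = mean_square_derivative(a.detach()) * gelu_derivative((x_in @ w).detach())
    assert torch.allclose(w.grad, x_in.detach().T @ grad_z, atol=1e-5)
    assert torch.allclose(x_in.grad, grad_z @ w.detach().T, atol=1e-5)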
############################################################
def setup(rank: int, world_size: int):
# Specify where master lives (rank 0), used to coordinate (actual data goes through NCCL)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "15623"
if torch.cuda.is_available():
dist.init_process_group("nccl", rank=rank, world_size=world_size)
else:
dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup():
    dist.destroy_process_group()
if __name__ == "__main__":
    tensor_parallelism()