CPU implementation. This is a fun side project of mine: building a Mixture-of-Experts (MoE) model from scratch that runs on CPU.
import math
import inspect
from dataclasses import dataclass
from contextlib import nullcontext
import torch
import torch.nn as nn
from torch.nn import functional as F
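# Every module below takes a `config` object that the listing never defines. The dataclass here is a
# minimal sketch of the fields the code actually reads, inferred from usage; the name `MoEConfig`
# and every default value are assumptions, not the project's real hyperparameters.
@dataclass
class MoEConfig:
    # transformer settings
    block_size: int = 256          # maximum sequence length (size of the causal mask)
    n_embd: int = 128              # embedding / model width
    n_head: int = 4                # number of attention heads
    dropout: float = 0.0
    bias: bool = False             # include bias terms in Linear / LayerNorm
    # MoE settings
    n_exp: int = 4                 # number of experts
    top_k: int = 2                 # experts activated per token
    use_noisy_top_k: bool = False  # add noise to router logits before top-k
    use_aux_loss: bool = True      # load-balancing auxiliary loss
    use_router_z_loss: bool = True # penalize large router logits
    train_capacity: float = 1.25   # capacity factor during training
    eval_capacity: float = 2.0     # capacity factor during evaluation
    min_capacity: int = 4          # lower bound on expert capacity
    router_use_full_prec: bool = True  # run the router in full precision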
# Placeholder helpers from the original listing. They are invoked as Router methods
# (self.compute_aux_loss, self.compute_router_z_loss, self.get_capacity) inside the Router class,
# where working sketches of all three are given below.
def compute_aux_loss(self, expert_probs, indices):
    pass  # load-balancing auxiliary loss over the router probabilities
def compute_router_z_loss(self, logits):
    pass  # z-loss penalizing large router logits
def get_capacity(self, tokens_per_batch):
    pass  # number of token slots each expert can take per batch
class LayerNorm(nn.Module):
def __init__(self, ndim, bias):
super().__init__()
        self.ndim = ndim
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None  # optional bias; None is passed to F.layer_norm when disabled
def forward(self, input):
return F.layer_norm(input, normalized_shape=self.weight.shape, weight=self.weight, bias=self.bias, eps=1e-5)
class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
self.c_attn = nn.Linear(config.n_embd, config.n_embd*3, bias=config.bias)
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
self.attn_dropout = nn.Dropout(config.dropout)
self.residual_dropout = nn.Dropout(config.dropout)
# book keeping
self.n_embd = config.n_embd
self.n_head = config.n_head
        # causal attention mask: lower-triangular, so each position only attends to itself and earlier positions
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))
def forward(self, x):
B, T, C = x.shape
q, k, v = self.c_attn(x).split(self.n_embd, dim = -1)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1,2) # move the head dim before the sequence dim so attention (and its softmax) runs per head over [T, T]
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1,2)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1,2)
att = torch.einsum('bhtd,bhzd->bhtz', q,k)/math.sqrt(k.size(-1))
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = torch.softmax(att, dim = -1)
att = self.attn_dropout(att)
y = torch.einsum('bhtz,bhzd->bhtd', att, v)
y = y.transpose(1,2).contiguous().view(B,T,C)
        y = self.residual_dropout(self.c_proj(y))
return y
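# On PyTorch >= 2.0 the einsum + mask + softmax above can be swapped for the fused
# F.scaled_dot_product_attention, which avoids materializing the [T, T] mask and is usually
# faster even on CPU. A minimal sketch, not wired into the class above:
def sdpa_attention(q, k, v, dropout_p=0.0):
    # q, k, v: [B, n_head, T, head_dim]; is_causal=True applies the same lower-triangular mask
    return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=True)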
class Router(nn.Module):
def __init__(self, config):
super().__init__()
# book keeping
self.top_k = config.top_k
self.n_exp = config.n_exp
assert self.top_k >= 1 and self.n_exp >= self.top_k
self.use_noisy_top_k = config.use_noisy_top_k
self.train_capacity = config.train_capacity
self.eval_capacity = config.eval_capacity
self.min_capacity = config.min_capacity
        self.router_use_full_prec = config.router_use_full_prec  # full-precision routing flag (stored but not used in this CPU-only listing)
# aux loss for router
self.use_aux_loss = config.use_aux_loss
self.use_router_z_loss = config.use_router_z_loss
# param
        self.w_g = nn.Linear(config.n_embd, config.n_exp, bias=False) # no bias; an ad hoc per-expert bias can still be applied during training for load balancing
self.w_noise = nn.Linear(config.n_embd, config.n_exp, bias=False)
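    # The three helpers below are a sketch of what the module-level stubs at the top of the file
    # are presumably meant to compute; the formulas follow the standard Switch-Transformer
    # load-balancing loss, the ST-MoE router z-loss, and the usual capacity-factor rule. This is
    # an assumption about the author's intent, not a confirmed implementation.
    def compute_aux_loss(self, expert_probs, indices):
        # load balancing: fraction of assignments per expert times mean router probability per expert
        one_hot = F.one_hot(indices, num_classes=self.n_exp).float()   # [B, T, k, n_exp]
        tokens_per_exp = one_hot.sum(dim=(0, 1, 2)) / one_hot.sum()    # [n_exp]
        probs_per_exp = expert_probs.mean(dim=(0, 1))                  # [n_exp]
        return self.n_exp * torch.sum(tokens_per_exp * probs_per_exp)
    def compute_router_z_loss(self, logits):
        # penalize large router logits to keep routing numerically stable
        return torch.mean(torch.logsumexp(logits, dim=-1) ** 2)
    def get_capacity(self, tokens_per_batch):
        # each token occupies top_k expert slots; scale by a capacity factor and floor at min_capacity
        capacity_factor = self.train_capacity if self.training else self.eval_capacity
        capacity = int(self.top_k * capacity_factor * tokens_per_batch / self.n_exp)
        return max(capacity, self.min_capacity)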
def forward(self, x): # x [bs, seq_len, dim]
B, T, _ = x.size()
num_tokens = B*T
logits = self.w_g(x) # [B, T, n_exp]
if self.use_noisy_top_k:
noise = F.softplus(self.w_noise(x)) # compute noise std
noise *= torch.randn_like(noise)
logits += noise
        if self.use_router_z_loss:
            z_loss = self.compute_router_z_loss(logits)  # NOTE: computed but not returned or accumulated in this listing
top_k_logits, top_k_indices = torch.topk(logits,k=self.top_k,dim=-1) # [B,T,k]
router_probs = torch.full_like(logits, fill_value=float('-inf')) # [B, T, n_exp]
router_probs = torch.scatter(router_probs, dim = -1, index=top_k_indices, src=top_k_logits)
router_probs = F.softmax(router_probs, dim = -1)
        if self.use_aux_loss:
            aux_loss = self.compute_aux_loss(router_probs, top_k_indices)  # NOTE: also not yet added to the training loss
exp_capacity = self.get_capacity(num_tokens)
exp_mask = F.one_hot(top_k_indices, num_classes=self.n_exp) #[B,T,k,n_exp]
exp_mask = exp_mask.view(num_tokens,self.top_k, self.n_exp) # [B*T, k, n_exp]
exp_mask = exp_mask.permute(1,0,2) # [k, B*T, n_exp]
exp_rank = exp_mask.reshape(self.top_k * num_tokens, self.n_exp) # [B*T*k, n_exp]
        exp_rank = torch.cumsum(exp_rank, dim=0) - 1 # cumulative sum of expert selections [k*B*T, n_exp]
exp_rank = exp_rank.reshape(self.top_k, num_tokens, self.n_exp) # [k, B*T, n_exp]
# enforce expert capacity
exp_mask *= torch.lt(exp_rank, exp_capacity) # [k, B*T, n_exp]
        used_capacity = torch.sum(exp_mask, dim=(0, 1)) # [n_exp]
        # reduce the rank tensor to each token's slot index within its chosen expert
        exp_rank = torch.sum(exp_mask * exp_rank, dim=-1) # [k, B*T]
        router_probs = router_probs.view(num_tokens, self.n_exp)[None, :] # [1, B*T, n_exp]
        exp_weights = exp_mask * router_probs # [k, B*T, n_exp]
        exp_rank_sc = F.one_hot(exp_rank, num_classes=exp_capacity) # [k, B*T, exp_capacity]
        cb_weight = torch.sum(exp_weights.unsqueeze(3) * exp_rank_sc.unsqueeze(2), dim = 0) # [k, B*T, n_exp, 1] * [k, B*T, 1, exp_capacity] --> [B*T, n_exp, exp_capacity]
        sec_mask = cb_weight.bool() # binary mask of selected expert/slot pairs for each token
'''
used_capacity: [n_exp]
cb_weight: [num_tokens, n_exp, exp_capacity]
sec_mask: [num_tokens, n_exp, exp_capacity]
'''
return used_capacity, cb_weight, sec_mask
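# A quick shape sanity check for the router; it relies on the MoEConfig sketch above and only runs
# when called explicitly.
def _demo_router():
    cfg = MoEConfig(block_size=8, n_embd=16, n_head=4, n_exp=4, top_k=2)
    router = Router(cfg)
    x = torch.randn(2, 8, cfg.n_embd)          # [B, T, n_embd]
    used_capacity, cb_weight, sec_mask = router(x)
    print(used_capacity.shape)                 # [n_exp]
    print(cb_weight.shape)                     # [B*T, n_exp, exp_capacity]
    print(sec_mask.shape)                      # [B*T, n_exp, exp_capacity]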
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4*config.n_embd,bias=config.bias)
self.gelu = nn.GELU()
self.c_proj = nn.Linear(4*config.n_embd, config.n_embd, bias=config.bias)
self.dropout = nn.Dropout(config.dropout)
def forward(self,x):
x = self.c_fc(x)
x = self.gelu(x)
x = self.c_proj(x)
x = self.dropout(x)
return x
class MLPExperts(nn.Module):
def __init__(self, config):
super().__init__()
self.bias = config.bias
        # expert weights are stored batched as [n_exp, ...] so all experts run in one bmm call;
        # initialization is assumed to be GPT-2-style normal(0, 0.02) with zero biases
        self.c_fc = nn.Parameter(torch.normal(mean=0.0, std=0.02, size=(config.n_exp, config.n_embd, 4 * config.n_embd)))
        self.c_proj = nn.Parameter(torch.normal(mean=0.0, std=0.02, size=(config.n_exp, 4 * config.n_embd, config.n_embd)))
        self.fc_bias = nn.Parameter(torch.zeros(config.n_exp, 1, 4 * config.n_embd)) if self.bias else None
        self.proj_bias = nn.Parameter(torch.zeros(config.n_exp, 1, config.n_embd)) if self.bias else None
self.gelu = nn.GELU()
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
x = torch.bmm(x, self.c_fc)
if self.bias:
x += self.fc_bias
x = self.gelu(x)
x = torch.bmm(x, self.c_proj)
if self.bias:
x += self.proj_bias
x = self.dropout(x)
return x
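# Shape note: torch.bmm treats the leading n_exp dimension as a batch, so each expert's weight
# matrix only ever multiplies its own [exp_capacity, n_embd] slice; all experts run as a single
# batched matmul, e.g.
#   torch.bmm(torch.randn(4, 10, 16), torch.randn(4, 16, 64)).shape  ->  torch.Size([4, 10, 64])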
class MOELayer(nn.Module):
def __init__(self, config):
super().__init__()
self.router = Router(config)
self.experts = MLPExperts(config)
def forward(self, x):
B, T, n_embd = x.size()
num_tokens = B*T
used_capacity, exp_weight, exp_mask = self.router(x)
        # this batching is a clever trick: the routing masks partition the tokens into
        # fixed-capacity per-expert batches, so every expert runs as one dense matmul
x = x.view(num_tokens, n_embd) # [B, T, n_embd] --> [n_tokens, n_embd]
exp_mask = exp_mask.permute(1,2,0).type_as(x) # [num_tokens, n_exp, exp_capacity] --> [n_exp, exp_capacity, num_tokens]
exp_batches = exp_mask @ x # [n_exp, exp_capacity, num_tokens] @ [num_tokens, n_embd] --> [n_exp, exp_capacity, n_embd]
exp_out = self.experts(exp_batches) # [n_exp, exp_capacity, n_embd]
exp_weight = exp_weight.view(num_tokens, -1) #[num_tokens, n_exp*exp_capacity]
exp_out = exp_out.view(-1,n_embd) # [n_exp*exp_capacity, n_embd]
output = exp_weight @ exp_out # [B*T, n_embd]
return output.view(B,T,n_embd)
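# End-to-end check of the dispatch/combine above: sec_mask gathers tokens into fixed-capacity
# per-expert batches, the experts run as one batched matmul, and cb_weight scatters the outputs
# back scaled by the router probabilities. Relies on the MoEConfig sketch above; call explicitly.
def _demo_moe_layer():
    cfg = MoEConfig(block_size=8, n_embd=16, n_head=4, n_exp=4, top_k=2)
    layer = MOELayer(cfg)
    x = torch.randn(2, 8, cfg.n_embd)
    print(layer(x).shape)                      # torch.Size([2, 8, 16]), same shape as the input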
class Block(nn.Module):
def __init__(self, config, use_moe=False):
super().__init__()
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
self.attn = CausalSelfAttention(config)
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
if use_moe:
self.mlp = MOELayer(config)
else:
self.mlp = MLP(config)
def forward(self, x):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
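# Example wiring (an assumption about how the blocks are meant to be stacked): dense MLP blocks
# interleaved with MoE blocks, e.g. a MoE layer in every second block, which is a common pattern
# for MoE transformers.
def make_blocks(config, n_layer=4, moe_every=2):
    return nn.ModuleList(
        [Block(config, use_moe=(i % moe_every == moe_every - 1)) for i in range(n_layer)]
    )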