CPU implementation

This is a fun side project of mine: building a small Mixture-of-Experts (MoE) transformer from scratch in PyTorch, simple enough to run on a CPU.


import math
import inspect
from dataclasses import dataclass
from contextlib import nullcontext

import torch
import torch.nn as nn
from torch.nn import functional as F

# NOTE: the router's helper methods (compute_aux_loss, compute_router_z_loss,
# get_capacity) are defined inside the Router class below, since that is where
# they are called via self.

class LayerNorm(nn.Module):
    def __init__(self, ndim, bias):
        super().__init__()
        self.ndim = ndim
        self.weight = nn.Parameter(torch.ones(ndim))
        # only create a learnable bias when requested; F.layer_norm accepts bias=None
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
    
    def forward(self, input):
        return F.layer_norm(input, normalized_shape=self.weight.shape, weight=self.weight, bias=self.bias, eps=1e-5)
    

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, config.n_embd*3, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.residual_dropout = nn.Dropout(config.dropout)

        # bookkeeping
        self.n_embd = config.n_embd
        self.n_head = config.n_head

        # causal attention mask: a lower-triangular matrix registered as a (non-trainable) buffer
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))
    

    def forward(self, x):
        B, T, C = x.shape

        q, k, v = self.c_attn(x).split(self.n_embd, dim = -1)
        
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1,2) # [B, n_head, T, head_dim]: move the head dim next to batch so attention is computed per head
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1,2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1,2)

        att = torch.einsum('bhtd,bhzd->bhtz', q,k)/math.sqrt(k.size(-1))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = torch.softmax(att, dim = -1)
        att = self.attn_dropout(att)
        y = torch.einsum('bhtz,bhzd->bhtd', att, v)

        y = y.transpose(1,2).contiguous().view(B,T,C)

        y = self.residual_dropout(self.c_proj(y)) # output projection + residual-path dropout

        return y

class Router(nn.Module):

    def __init__(self, config):
        super().__init__()

        # bookkeeping
        self.top_k = config.top_k
        self.n_exp = config.n_exp
        assert self.top_k >= 1 and self.n_exp >= self.top_k
        self.use_noisy_top_k = config.use_noisy_top_k
        self.train_capacity = config.train_capacity
        self.eval_capacity = config.eval_capacity
        self.min_capacity = config.min_capacity
        self.router_use_full_prec = config.router_use_full_prec

        # auxiliary losses for the router; the most recent values are stashed on the
        # module in forward() so the training loop can add them to the main loss
        self.use_aux_loss = config.use_aux_loss
        self.use_router_z_loss = config.use_router_z_loss
        self.aux_loss = None
        self.z_loss = None

        # param
        self.w_g = nn.Linear(config.n_embd, config.n_exp, bias=False) # no bias --> however, an ad-hoc bias can still be added to the logits during training for load balancing
        self.w_noise = nn.Linear(config.n_embd, config.n_exp, bias=False)

    def forward(self, x): # x [bs, seq_len, dim]
        B, T, _ = x.size()

        num_tokens = B*T
        logits = self.w_g(x) # [B, T, n_exp]
        if self.use_noisy_top_k:
            noise = F.softplus(self.w_noise(x)) # compute noise std
            noise *= torch.randn_like(noise)
            logits += noise
        
        if self.use_router_z_loss:
            self.z_loss = self.compute_router_z_loss(logits) # stashed for the training loop

        top_k_logits, top_k_indices = torch.topk(logits,k=self.top_k,dim=-1) # [B,T,k]

        router_probs = torch.full_like(logits, fill_value=float('-inf')) # [B, T, n_exp]
        router_probs = torch.scatter(router_probs, dim = -1, index=top_k_indices, src=top_k_logits)
        router_probs = F.softmax(router_probs, dim = -1)

        if self.use_aux_loss:
            self.aux_loss = self.compute_aux_loss(router_probs, top_k_indices) # stashed for the training loop
        
        exp_capacity = self.get_capacity(num_tokens)

        exp_mask = F.one_hot(top_k_indices, num_classes=self.n_exp) #[B,T,k,n_exp]
        exp_mask = exp_mask.view(num_tokens,self.top_k, self.n_exp) # [B*T, k, n_exp]
        exp_mask = exp_mask.permute(1,0,2) # [k, B*T, n_exp]

        exp_rank = exp_mask.reshape(self.top_k * num_tokens, self.n_exp) # [k*B*T, n_exp]
        exp_rank = torch.cumsum(exp_rank, dim=0) - 1 # running count of expert selections; flattening with k first gives top-1 choices priority for capacity [k*B*T, n_exp]
        exp_rank = exp_rank.reshape(self.top_k, num_tokens, self.n_exp) # [k, B*T, n_exp]

        # enforce expert capacity: drop assignments whose rank exceeds the expert's slot budget
        exp_mask *= torch.lt(exp_rank, exp_capacity) # [k, B*T, n_exp]
        used_capacity = torch.sum(exp_mask, dim=(0, 1)) # [n_exp]

        # reduce the rank tensor to each surviving token's slot index within its expert's buffer
        exp_rank = torch.sum(exp_mask * exp_rank, dim=-1) # [k, B*T]

        router_probs = router_probs.view(num_tokens, self.n_exp)[None, :] # [1, B*T, n_exp]
        exp_weights = exp_mask * router_probs # [k, B*T, n_exp]

        exp_rank_sc = F.one_hot(exp_rank, num_classes=exp_capacity) # [k, B*T, exp_capacity]

        cb_weight = torch.sum(exp_weights.unsqueeze(3) * exp_rank_sc.unsqueeze(2), dim = 0) # [k, B*T, n_exp, 1] * [k, B*T, 1, exp_capacity] --> [B*T, n_exp, exp_capacity]
        sec_mask = cb_weight.bool() # binary mask of the slot selected for each (token, expert) pair

        '''
        used_capacity: [n_exp]
        cb_weight: [num_tokens, n_exp, exp_capacity]
        sec_mask: [num_tokens, n_exp, exp_capacity]
        '''
        return used_capacity, cb_weight, sec_mask
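
    # The three helpers below follow standard MoE formulations: a Switch-Transformer-style
    # load-balancing loss, an ST-MoE-style router z-loss, and a capacity-factor slot budget.
    # Treat them as one reasonable sketch of these quantities rather than the only choice.

    def compute_aux_loss(self, expert_probs, indices):
        # load-balancing loss: n_exp * sum_i f_i * P_i, where f_i is the fraction of routed
        # (token, slot) pairs assigned to expert i and P_i is the mean router probability
        # given to expert i; the loss is minimized when expert usage is uniform
        one_hot = F.one_hot(indices, num_classes=self.n_exp).float() # [B, T, k, n_exp]
        tokens_per_exp = one_hot.sum(dim=(0, 1, 2)) / one_hot.sum() # f_i, [n_exp]
        prob_per_exp = expert_probs.mean(dim=(0, 1)) # P_i, [n_exp]
        return self.n_exp * torch.sum(tokens_per_exp * prob_per_exp)

    def compute_router_z_loss(self, logits):
        # z-loss: mean over tokens of logsumexp(logits)^2, which discourages the router
        # from producing very large logits (helps numerical stability)
        return torch.mean(torch.logsumexp(logits, dim=-1) ** 2)

    def get_capacity(self, tokens_per_batch):
        # per-expert slot budget: capacity_factor * (k * tokens) / n_exp, floored at min_capacity
        capacity_factor = self.train_capacity if self.training else self.eval_capacity
        capacity = int(self.top_k * capacity_factor * tokens_per_batch / self.n_exp)
        return max(capacity, self.min_capacity)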

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4*config.n_embd,bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4*config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self,x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class MLPExperts(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.bias = config.bias

        # one fused weight tensor per expert so all experts can be applied with a single bmm;
        # initialized here directly (GPT-2-style N(0, 0.02) weights, zero biases) since this
        # standalone version has no separate _init_weights pass
        self.c_fc = nn.Parameter(torch.normal(mean=0.0, std=0.02, size=(config.n_exp, config.n_embd, 4 * config.n_embd)))
        self.c_proj = nn.Parameter(torch.normal(mean=0.0, std=0.02, size=(config.n_exp, 4 * config.n_embd, config.n_embd)))
        self.fc_bias = nn.Parameter(torch.zeros(config.n_exp, 1, 4 * config.n_embd)) if self.bias else None
        self.proj_bias = nn.Parameter(torch.zeros(config.n_exp, 1, config.n_embd)) if self.bias else None
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.dropout)
    
    def forward(self, x):
        x = torch.bmm(x, self.c_fc)
        if self.bias:
            x += self.fc_bias
        x = self.gelu(x)
        x = torch.bmm(x, self.c_proj)
        if self.bias:
            x += self.proj_bias
        x = self.dropout(x)
        return x

class MOELayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.router = Router(config)
        self.experts = MLPExperts(config)

    def forward(self, x):
        B, T, n_embd = x.size()
        num_tokens = B*T

        used_capacity, exp_weight, exp_mask = self.router(x)
        
        # a clever batching trick: multiplying the dispatch mask by the token matrix gathers
        # every kept token into its assigned expert's buffer in a single matmul, so each
        # expert then processes a dense [exp_capacity, n_embd] batch
        x = x.view(num_tokens, n_embd) # [B, T, n_embd] --> [num_tokens, n_embd]
        exp_mask = exp_mask.permute(1,2,0).type_as(x) # [num_tokens, n_exp, exp_capacity] --> [n_exp, exp_capacity, num_tokens]
        exp_batches = exp_mask @ x # [n_exp, exp_capacity, num_tokens] @ [num_tokens, n_embd] --> [n_exp, exp_capacity, n_embd]

        exp_out = self.experts(exp_batches) # [n_exp, exp_capacity, n_embd]

        exp_weight = exp_weight.view(num_tokens, -1) #[num_tokens, n_exp*exp_capacity]
        exp_out = exp_out.view(-1,n_embd) # [n_exp*exp_capacity, n_embd]
        output = exp_weight @ exp_out # [B*T, n_embd]
        
        return output.view(B,T,n_embd)
    
class Block(nn.Module):
    
    def __init__(self, config, use_moe=False):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        if use_moe:
            self.mlp = MOELayer(config)
        else:
            self.mlp = MLP(config)
    
    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))

        return x
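

To check that everything runs end to end on a CPU, here is a small smoke test. The MoEConfig dataclass below is a sketch of my own: its field names are exactly the attributes the classes above read off config, but the class name and the default values are illustrative choices, not a fixed part of the model.

@dataclass
class MoEConfig:
    block_size: int = 64
    n_embd: int = 32
    n_head: int = 4
    n_exp: int = 4
    top_k: int = 2
    dropout: float = 0.0
    bias: bool = True
    use_noisy_top_k: bool = True
    use_aux_loss: bool = True
    use_router_z_loss: bool = True
    train_capacity: float = 1.25
    eval_capacity: float = 2.0
    min_capacity: int = 4
    router_use_full_prec: bool = True


if __name__ == "__main__":
    config = MoEConfig()
    block = Block(config, use_moe=True) # one transformer block with an MoE feed-forward
    x = torch.randn(2, config.block_size, config.n_embd) # dummy activations [B, T, n_embd]
    y = block(x)
    print(y.shape) # torch.Size([2, 64, 32])

    # the router's auxiliary losses are stashed on the module during forward();
    # in training they would be scaled by small coefficients and added to the LM loss
    for m in block.modules():
        if isinstance(m, Router):
            print(m.aux_loss, m.z_loss)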