bittensor-subnet-template
Add MoE Gating model base
Create a gating model for a Mixture of Experts (MoE) architecture using PyTorch. We can implement a soft gating mechanism in which the gating weights act as probabilities for selecting the different experts, and use the Gumbel-Softmax trick to sample from that categorical distribution at a chosen temperature, keeping the sampling step differentiable.
This should be part of the validator, as most subnets will want some form of automatic routing mechanism without having to reinvent the wheel.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatingModel(nn.Module):
    def __init__(self, input_dim, num_experts, temperature=1.0):
        super(GatingModel, self).__init__()
        self.num_experts = num_experts
        self.temperature = temperature

        # Gating network: maps each input to a probability distribution over experts.
        self.gating_network = nn.Sequential(
            nn.Linear(input_dim, num_experts),
            nn.Softmax(dim=-1),  # softmax along the expert dimension
        )

    def forward(self, input):
        # Calculate gating probabilities.
        gating_probs = self.gating_network(input)

        # Gumbel-Softmax sampling for a differentiable, near-discrete selection:
        # add Gumbel noise to the log-probabilities and re-normalise with a
        # temperature-scaled softmax.
        gumbel_noise = torch.rand_like(gating_probs)
        gumbel_noise = -torch.log(-torch.log(gumbel_noise + 1e-20) + 1e-20)  # Gumbel(0, 1) noise
        logits = (torch.log(gating_probs + 1e-20) + gumbel_noise) / self.temperature
        selected_experts = F.softmax(logits, dim=-1)

        # Combine the input under the per-expert weights: broadcast the weights
        # over the feature dimension and sum across the expert dimension.
        output = torch.sum(selected_experts.unsqueeze(-1) * input.unsqueeze(-2), dim=-2)
        return output, selected_experts
```
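As a side note, the sampling step above can likely be expressed with PyTorch's built-in `F.gumbel_softmax`, which takes unnormalised logits (log-probabilities work too). A minimal sketch, not part of the proposed class, using a random stand-in for the gating network output:

```python
import torch
import torch.nn.functional as F

temperature = 0.1
gating_probs = torch.softmax(torch.randn(32, 5), dim=-1)  # stand-in for the gating network output

# Soft sample, equivalent in spirit to the manual noise-and-softmax step above.
soft_weights = F.gumbel_softmax(torch.log(gating_probs + 1e-20), tau=temperature)

# With hard=True the forward pass yields a one-hot selection while gradients
# still flow through the soft sample (straight-through estimator).
hard_weights = F.gumbel_softmax(torch.log(gating_probs + 1e-20), tau=temperature, hard=True)
```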
Example usage:

```python
input_dim = 10
num_experts = 5
temperature = 0.1

# Create a GatingModel.
gating_model = GatingModel(input_dim, num_experts, temperature)

# Generate dummy input (a batch of 32 examples).
input_data = torch.randn(32, input_dim)

# Forward pass through the gating model.
output, selected_experts = gating_model(input_data)
```

The `output` is the weighted combination of the input under the gating weights, and `selected_experts` is the soft weight vector over experts for each example; at low temperatures it approaches a one-hot selection of a single expert.
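Note that the base class above re-weights a single shared input, so the weighted sum reduces to the input itself; in a real validator the same gating weights would combine the outputs of distinct expert modules. A minimal sketch of that wiring, where the `MixtureOfExperts` wrapper and the placeholder experts are hypothetical and assume the `GatingModel` defined above is in scope:

```python
import torch
import torch.nn as nn


class MixtureOfExperts(nn.Module):
    """Hypothetical wrapper: combines expert outputs using GatingModel weights."""

    def __init__(self, input_dim, hidden_dim, num_experts, temperature=1.0):
        super().__init__()
        self.gating_model = GatingModel(input_dim, num_experts, temperature)
        # Placeholder experts; a subnet would plug in its own modules here.
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, input_dim),
            )
            for _ in range(num_experts)
        )

    def forward(self, x):
        # Gating weights per example: (batch, num_experts).
        _, weights = self.gating_model(x)
        # Expert outputs stacked along a new expert dimension: (batch, num_experts, input_dim).
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        # Weighted sum over the expert dimension: (batch, input_dim).
        return torch.sum(weights.unsqueeze(-1) * expert_outputs, dim=1)


# Example usage
moe = MixtureOfExperts(input_dim=10, hidden_dim=32, num_experts=5, temperature=0.1)
out = moe(torch.randn(32, 10))
```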