
Add MoE Gating model base

ifrit98 opened this issue 2 years ago

Create a gating model for a Mixture of Experts (MoE) architecture using PyTorch. We can implement a soft gating mechanism in which the gating weights act as probabilities for selecting different experts, and use the Gumbel-Softmax trick to sample from that categorical distribution with a temperature, keeping the sampling process differentiable.
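The trick itself fits in a few lines. A minimal sketch (the logits and the 0.5 temperature are illustrative values, not part of the template):

```python
import torch
import torch.nn.functional as F

# Unnormalized logits of a categorical distribution over 4 experts
logits = torch.tensor([[2.0, 0.5, 0.1, -1.0]])

# Manual Gumbel-Softmax: add Gumbel(0, 1) noise to the logits, then apply a
# temperature-scaled softmax. The result is a differentiable relaxation of a
# one-hot sample from softmax(logits).
gumbel = -torch.log(-torch.log(torch.rand_like(logits)))
soft_sample = F.softmax((logits + gumbel) / 0.5, dim=-1)

# PyTorch ships the same trick as F.gumbel_softmax; hard=True returns an exact
# one-hot sample while keeping gradients via the straight-through estimator.
hard_sample = F.gumbel_softmax(logits, tau=0.5, hard=True)
```

Lowering the temperature pushes the soft sample toward one-hot; raising it pushes it toward uniform.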

This should be part of the validator, as most subnets will want some kind of automatic routing mechanism without having to reinvent the wheel.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatingModel(nn.Module):
    def __init__(self, input_dim, num_experts, temperature=1.0):
        super().__init__()

        self.num_experts = num_experts
        self.temperature = temperature

        # Gating network: maps an input to unnormalized logits over experts.
        # No Softmax here -- Gumbel-Softmax expects raw logits.
        self.gating_network = nn.Linear(input_dim, num_experts)

    def forward(self, input, expert_outputs=None):
        # Unnormalized gating logits, shape (batch, num_experts)
        logits = self.gating_network(input)

        # Gumbel-Softmax sampling: a differentiable relaxation of drawing a
        # discrete expert from the categorical distribution. Lower temperatures
        # push the weights toward a one-hot selection.
        selected_experts = F.gumbel_softmax(logits, tau=self.temperature, dim=-1)

        # Weighted sum over expert outputs of shape (batch, num_experts, dim).
        # If no expert outputs are supplied, broadcast the input as a stand-in
        # so the module can be exercised on its own.
        if expert_outputs is None:
            expert_outputs = input.unsqueeze(-2).expand(-1, self.num_experts, -1)
        output = torch.sum(selected_experts.unsqueeze(-1) * expert_outputs, dim=-2)

        return output, selected_experts

# Example usage
input_dim = 10
num_experts = 5
temperature = 0.1

# Create a GatingModel
gating_model = GatingModel(input_dim, num_experts, temperature)

# Generate dummy input
input_data = torch.randn(32, input_dim)

# Forward pass through the gating model
output, selected_experts = gating_model(input_data)

# The 'output' is the combined MoE output, and 'selected_experts' holds the soft
# gating weights (approximately one-hot at low temperature) indicating which
# expert was favoured for each example.
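For completeness, a sketch of wiring the gating weights to actual expert networks; the per-expert linear layers here are hypothetical placeholders, since the issue only specifies the gating side:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

input_dim, num_experts, batch = 10, 5, 32

# Hypothetical experts: one linear layer each (illustrative only).
experts = nn.ModuleList(nn.Linear(input_dim, input_dim) for _ in range(num_experts))
gate = nn.Linear(input_dim, num_experts)

x = torch.randn(batch, input_dim)

# Stack every expert's output: (batch, num_experts, input_dim)
expert_outputs = torch.stack([expert(x) for expert in experts], dim=-2)

# Differentiable near-discrete gating weights: (batch, num_experts)
weights = F.gumbel_softmax(gate(x), tau=0.1, dim=-1)

# Weighted combination of expert outputs: (batch, input_dim)
mixed = torch.sum(weights.unsqueeze(-1) * expert_outputs, dim=-2)
```

Because the whole path is differentiable, the gate and the experts can be trained end-to-end with ordinary backprop.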

ifrit98, Nov 22 '23