
OOM with Mamba-2 on U-Mamba

Open · ZagButNoZig opened this issue 8 months ago · 1 comment

I'm playing around with the U-Mamba model and I run into out-of-memory issues when replacing

self.mamba = Mamba(
    d_model=dim,      # Model dimension d_model
    d_state=d_state,  # SSM state expansion factor
    d_conv=d_conv,    # Local convolution width
    expand=expand,    # Block expansion factor
)

with

self.mamba = Mamba2(
    d_model=dim,      # Model dimension d_model
    d_state=d_state,  # SSM state expansion factor
    d_conv=d_conv,    # Local convolution width
    expand=expand,    # Block expansion factor
    headdim=d_state,
)

even though I would expect Mamba-2 to consume less memory.
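
For context, my understanding is that Mamba2 splits its inner dimension into heads via headdim, so passing headdim=d_state as above gives a very different head layout from the Mamba2 defaults (headdim=64, d_state=128). A rough back-of-envelope sketch for my configuration (the nheads formula is my assumption about Mamba2's constructor, not something taken from U-Mamba):

# Back-of-envelope head layout for the config used in the repro below.
# Assumption: Mamba2 computes d_inner = expand * d_model and
# nheads = d_inner // headdim (it requires d_inner % headdim == 0).
d_model, d_state, expand = 64, 16, 2
headdim = d_state              # what the snippet above passes

d_inner = expand * d_model     # 128
nheads = d_inner // headdim    # 8 heads of width 16
print(f"d_inner={d_inner}, nheads={nheads}, headdim={headdim}")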

I wrote a small repro, which gives me:


Mamba 3D (forward)  peak: 6100.38 MB
Mamba 3D (backward) peak: 9516.54 MB
Mamba2 3D (forward)  peak: 6468.64 MB
Mamba2 3D (backward) peak: 12424.77 MB

Code to reproduce

import torch
import torch.nn as nn

# Install via: pip install mamba-ssm
from mamba_ssm import Mamba, Mamba2

class Mamba3DLayer(nn.Module):
    def __init__(self, dim, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        # Mamba v1 has signature: Mamba(d_model, d_state, d_conv, expand)
        self.mamba = Mamba(
            d_model=dim,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
        )

    def forward(self, x):
        B, C = x.shape[:2]
        D, H, W = x.shape[2:]
        # flatten (B, C, D, H, W) → (B, N, C)
        x_flat = x.view(B, C, -1).transpose(1, 2)
        x_norm = self.norm(x_flat)
        y = self.mamba(x_norm)  # (B, N, C)
        # reshape back
        return y.transpose(1, 2).view(B, C, D, H, W)

class Mamba2_3DLayer(nn.Module):
    def __init__(self, dim, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        # Mamba2 additionally takes a headdim argument; here it is tied to d_state
        self.mamba = Mamba2(
            d_model=dim,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            headdim=d_state,
        )

    def forward(self, x):
        B, C = x.shape[:2]
        D, H, W = x.shape[2:]
        x_flat = x.view(B, C, -1).transpose(1, 2)
        x_norm = self.norm(x_flat)
        y = self.mamba(x_norm)
        return y.transpose(1, 2).view(B, C, D, H, W)

def measure_peak(module: nn.Module, input_tensor: torch.Tensor):
    device = input_tensor.device
    module = module.to(device).train()
    torch.cuda.reset_peak_memory_stats(device)

    out = module(input_tensor)
    torch.cuda.synchronize(device)
    peak_fwd = torch.cuda.max_memory_allocated(device)

    out.sum().backward()
    torch.cuda.synchronize(device)
    peak_bwd = torch.cuda.max_memory_allocated(device)

    return peak_fwd, peak_bwd

if __name__ == "__main__":
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required to measure GPU memory usage.")
    device = torch.device("cuda")

    # ---- CONFIGURATION ----
    batch_size = 4
    dim = 64
    depth, height, width = 64, 64, 64

    x = torch.randn(batch_size, dim, depth, height, width,
                    device=device, requires_grad=True)

    layer1 = Mamba3DLayer(dim=dim).to(device)
    layer2 = Mamba2_3DLayer(dim=dim).to(device)

    # Warm-up & clear cache
    _ = layer1(x)
    _ = layer2(x)
    torch.cuda.synchronize(device)
    torch.cuda.empty_cache()

    # Measure
    f1, b1 = measure_peak(layer1, x)
    torch.cuda.empty_cache()
    f2, b2 = measure_peak(layer2, x)

    print(f"Mamba 3D (forward)  peak: {f1/1024**2:.2f} MB")
    print(f"Mamba 3D (backward) peak: {b1/1024**2:.2f} MB")
    print(f"Mamba2 3D (forward)  peak: {f2/1024**2:.2f} MB")
    print(f"Mamba2 3D (backward) peak: {b2/1024**2:.2f} MB")

ZagButNoZig · Aug 02 '25 12:08

Can you share the environment details? Running your reproduce code on Colab, with the same pip install, gives me:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipython-input-3530614967.py in <cell line: 0>()
     83     # Warm-up & clear cache
     84     _ = layer1(x)
---> 85     _ = layer2(x)
     86     torch.cuda.synchronize(device)
     87     torch.cuda.empty_cache()

9 frames
/usr/local/lib/python3.12/dist-packages/mamba_ssm/ops/triton/ssd_combined.py in forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
    790         seq_idx = seq_idx.contiguous() if seq_idx is not None else None
    791         xBC_conv = rearrange(
--> 792             causal_conv1d_fwd_function(rearrange_and_update_stride(xBC, "b s d -> b d s"),
    793                                                  conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]),
    794             "b d s -> b s d"

TypeError: 'NoneType' object is not callable
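
The failing call is the causal-conv1d kernel, which mamba_ssm sets to None when the optional causal-conv1d package cannot be imported, so this may just be a missing or incompatible install on my side. A minimal check, assuming that is the cause:

import importlib

# Print the versions of the relevant packages; a failed import of
# causal_conv1d would explain the "'NoneType' object is not callable" error.
for pkg in ("torch", "mamba_ssm", "causal_conv1d"):
    try:
        mod = importlib.import_module(pkg)
        print(pkg, getattr(mod, "__version__", "unknown"))
    except ImportError as exc:
        print(pkg, "not importable:", exc)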

gayanku · Sep 03 '25 01:09