mamba
mamba copied to clipboard
OOM with Mamba-2 on U-Mamba
I'm playing around with the U-Mamba model and I get out of memory issues when replacing
self.mamba = Mamba(
d_model=dim, # Model dimension d_model
d_state=d_state, # SSM state expansion factor
d_conv=d_conv, # Local convolution width
expand=expand, # Block expansion factor
)
with
self.mamba = Mamba2(
d_model=dim, # Model dimension d_model
d_state=d_state, # SSM state expansion factor
d_conv=d_conv, # Local convolution width
expand=expand, # Block expansion factor
headdim=d_state,
)
even though I would expect Mamba-2 to consume less memory.
I wrote a small repro, which gives me:
Mamba 3D (forward) peak: 6100.38 MB
Mamba 3D (backward) peak: 9516.54 MB
Mamba2 3D (forward) peak: 6468.64 MB
Mamba2 3D (backward) peak: 12424.77 MB
Code to reproduce
import torch
import torch.nn as nn
# Install via: pip install mamba-ssm
from mamba_ssm import Mamba, Mamba2
class Mamba3DLayer(nn.Module):
    """LayerNorm followed by a Mamba (v1) block over a flattened 3-D volume.

    Input and output are (B, C, D, H, W); the three spatial axes are
    collapsed into one sequence axis of length N = D*H*W for the SSM.
    """

    def __init__(self, dim, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        # Mamba v1 signature: Mamba(d_model, d_state, d_conv, expand)
        self.mamba = Mamba(
            d_model=dim,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
        )

    def forward(self, x):
        batch, channels, depth, height, width = x.shape
        # (B, C, D, H, W) -> (B, N, C): channels become the model dimension.
        tokens = x.view(batch, channels, -1).transpose(1, 2)
        mixed = self.mamba(self.norm(tokens))  # (B, N, C)
        # Undo the flattening to restore the volumetric layout.
        return mixed.transpose(1, 2).view(batch, channels, depth, height, width)
class Mamba2_3DLayer(nn.Module):
    """LayerNorm followed by a Mamba-2 block over a flattened 3-D volume.

    Mirrors Mamba3DLayer but swaps in the Mamba2 mixer, which additionally
    takes a ``headdim`` argument (tied to ``d_state`` here, as in the repro).
    """

    def __init__(self, dim, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        # Unlike v1, Mamba2 exposes headdim as a constructor parameter.
        self.mamba = Mamba2(
            d_model=dim,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            headdim=d_state,
        )

    def forward(self, x):
        batch, channels, depth, height, width = x.shape
        # Collapse (D, H, W) into one sequence axis: (B, C, D, H, W) -> (B, N, C).
        tokens = x.view(batch, channels, -1).transpose(1, 2)
        mixed = self.mamba(self.norm(tokens))
        # Restore the original 5-D layout.
        return mixed.transpose(1, 2).view(batch, channels, depth, height, width)
def measure_peak(module: nn.Module, input_tensor: torch.Tensor):
    """Measure peak CUDA memory for a forward pass and a forward+backward pass.

    Args:
        module: the module to profile; moved to ``input_tensor``'s device and
            put in train mode.
        input_tensor: input on a CUDA device (may have ``requires_grad=True``).

    Returns:
        Tuple ``(peak_fwd, peak_bwd)`` in bytes, where ``peak_fwd`` is the
        allocator peak after the forward pass and ``peak_bwd`` the peak after
        backward (the backward figure therefore includes the forward peak).
    """
    device = input_tensor.device
    module = module.to(device).train()
    # Drop gradients left over from any earlier call (warm-up or a previous
    # measurement): stale .grad tensors on the parameters or the input would
    # inflate this run's allocation baseline and bias the comparison.
    module.zero_grad(set_to_none=True)
    input_tensor.grad = None
    # Make sure all pending kernels have finished before resetting the peak
    # counter, so we start from a quiescent allocator state.
    torch.cuda.synchronize(device)
    torch.cuda.reset_peak_memory_stats(device)
    out = module(input_tensor)
    torch.cuda.synchronize(device)
    peak_fwd = torch.cuda.max_memory_allocated(device)
    out.sum().backward()
    torch.cuda.synchronize(device)
    peak_bwd = torch.cuda.max_memory_allocated(device)
    return peak_fwd, peak_bwd
if __name__ == "__main__":
if not torch.cuda.is_available():
raise RuntimeError("CUDA is required to measure GPU memory usage.")
device = torch.device("cuda")
# ---- CONFIGURATION ----
batch_size = 4
dim = 64
depth, height, width = 64, 64, 64
x = torch.randn(batch_size, dim, depth, height, width,
device=device, requires_grad=True)
layer1 = Mamba3DLayer(dim=dim).to(device)
layer2 = Mamba2_3DLayer(dim=dim).to(device)
# Warm-up & clear cache
_ = layer1(x)
_ = layer2(x)
torch.cuda.synchronize(device)
torch.cuda.empty_cache()
# Measure
f1, b1 = measure_peak(layer1, x)
torch.cuda.empty_cache()
f2, b2 = measure_peak(layer2, x)
print(f"Mamba 3D (forward) peak: {f1/1024**2:.2f} MB")
print(f"Mamba 3D (backward) peak: {b1/1024**2:.2f} MB")
print(f"Mamba2 3D (forward) peak: {f2/1024**2:.2f} MB")
print(f"Mamba2 3D (backward) peak: {b2/1024**2:.2f} MB")
Can you share the environment details? Running your repro code on Colab with the same pip install gives me:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
[/tmp/ipython-input-3530614967.py](https://localhost:8080/#) in <cell line: 0>()
83 # Warm-up & clear cache
84 _ = layer1(x)
---> 85 _ = layer2(x)
86 torch.cuda.synchronize(device)
87 torch.cuda.empty_cache()
9 frames
[/usr/local/lib/python3.12/dist-packages/mamba_ssm/ops/triton/ssd_combined.py](https://localhost:8080/#) in forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
790 seq_idx = seq_idx.contiguous() if seq_idx is not None else None
791 xBC_conv = rearrange(
--> 792 causal_conv1d_fwd_function(rearrange_and_update_stride(xBC, "b s d -> b d s"),
793 conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]),
794 "b d s -> b s d"
TypeError: 'NoneType' object is not callable