[Feature] Support variable-length sequences for mamba block
Support variable-length sequences for the mamba block by passing cu_seqlens/seq_idx/position_ids through the forward and backward passes, similar to how the flash attention varlen_fwd/varlen_bwd API marks sequence boundaries (with cumulative sequence lengths cu_seqlens, or a lower triangular block diagonal attention mask).
In our tests, training with variable-length sequences on real-world datasets brought a 2~4x end-to-end speedup.
- Why do we need it? It gave high speedup and hardware utilization on the real-world datasets we tested. It improves hardware utilization when you have variable-length sequences and don't want to waste computing resources on meaningless padded tokens. It is especially useful for mamba training on real-world datasets, where the length distribution varies widely and a large proportion of samples are short sequences. Last but not least, we ensure exact fwd/bwd numerical equality with the padding approach.
- How to use it? Zero learning overhead: the packed mamba API is similar to the packed flash-attn API or the packed mamba2 API. Just pack multiple variable-length sequences into one and additionally pass cu_seqlens/seq_idx/position_ids into the mamba forward pass.
- No need to modify causal-conv1d; the original https://github.com/Dao-AILab/causal-conv1d (version>=1.4.0) works as is.
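To make the packing step in the list above concrete, here is a minimal sketch of how the three kinds of boundary metadata relate to a list of sequence lengths. It uses plain Python lists so it runs anywhere; in practice these would be integer tensors on the same device as the input. The helper name `build_varlen_metadata` is hypothetical, and the meanings assumed for seq_idx and position_ids follow the usual packed-sequence convention rather than anything confirmed by this PR.

```python
from itertools import accumulate


def build_varlen_metadata(seq_lens):
    """Hypothetical helper: derive boundary metadata for sequences packed back to back."""
    # cu_seqlens: cumulative offsets with a leading zero, length num_seqs + 1
    cu_seqlens = [0] + list(accumulate(seq_lens))
    # seq_idx: for every packed token, the index of the sequence it belongs to
    seq_idx = [i for i, n in enumerate(seq_lens) for _ in range(n)]
    # position_ids: position of every packed token inside its own sequence
    position_ids = [p for n in seq_lens for p in range(n)]
    return cu_seqlens, seq_idx, position_ids


cu_seqlens, seq_idx, position_ids = build_varlen_metadata([3, 1, 2])
print(cu_seqlens)    # [0, 3, 4, 6]
print(seq_idx)       # [0, 0, 0, 1, 2, 2]
print(position_ids)  # [0, 1, 2, 0, 0, 1]
```

All three encode the same boundaries, so any one of them can be derived from the others.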
Note: We thank @wang-zerui for the fwd pass python reference implementation and invaluable discussion on how to ensure numerical equality. This is joint work with @wang-zerui, @Dmovic, and @ptxu78.
Example usage: https://github.com/zigzagcai/varlen_mamba/blob/feat/add-cu_seqlens/tests/ops/test_mamba_varlen.py
Limitation:
- This PR currently works well for variable-length training, but variable-length generation (or inference) is not supported yet.
Some related issues about mamba and flash-attn variable-length training:
- https://github.com/state-spaces/mamba/issues/236
- https://github.com/state-spaces/mamba/issues/356
- https://github.com/state-spaces/mamba/issues/180
- https://github.com/state-spaces/mamba/issues/246#issuecomment-2003017621
- https://github.com/Dao-AILab/flash-attention/issues/850#issuecomment-1980308347
- https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1698610752
Hello @tridao @albertfgu
Thanks for the awesome work on mamba and it is really a strong competitor for transformer!
We have noticed some issues (https://github.com/state-spaces/mamba/issues/236, https://github.com/state-spaces/mamba/issues/180) stating a need for training on variable-length sequences. However, the mamba block lacks functionality such as attention_mask or cu_seqlens, which transformer architectures commonly use to support variable-length training.
Also, in real-world scenarios the length distribution of datasets varies widely, and simply padding every sample to the maximum length wastes computing resources on meaningless padded tokens.
So we implemented this PR and hope it helps!
Hello, it's great to see your input on variable length data. How can I use the method you provided? Is there any difference in results between it and padding?
Thank you for your interest in this PR! For the forward pass of the mamba block, we compared the outputs numerically against the padding approach and found them consistent (numerical equality for the forward pass has been verified). ~~For the backward pass, we plan to add some unit tests to show the consistency when we have bandwidth. (haven't verified numerical equality for backward pass)~~
Update (2024/03/19): Numerical equality for both the forward and backward pass has been validated. In terms of training loss and accuracy, this PR is numerically aligned with the padding approach, while avoiding wasted computation on meaningless padded tokens. When training on a sample dataset, variable-length training brings a high speedup compared to padding.
Thank you for your reply. Due to performance considerations, I would like to use bidirectional mamba. Should I wait for your updated code?
Hi @EricPaul03 ,
@Dmovic has created a unit test for the backward pass of the mamba block with variable-length sequences, and the test results show numerical equality for both the forward and backward pass with varlen inputs.
I haven't tried it with bidirectional mamba. But since it is numerically equivalent for the default unidirectional mamba, I think you can just give it a try!
To give a simple example: what we originally pass into the original mamba block is an input with shape (batch_size=7, seq_len=10, hidden_dim).
With this PR, we can instead pass the variable-length mamba block an input with shape (batch_size=1, seq_len=32, hidden_dim), where the original variable-length sequences are packed into one fixed-length sequence, with an additional parameter cu_seqlens to mark sequence boundaries.
From the above figure, we can clearly see that through this PR, mamba block can focus computing resources on variable-length sequences and avoid the overhead of meaningless padding tokens.
Variable-length training is very useful for optimizing the hardware utilization during training, and we know that the well-known flash attention has supported variable-length training via cu_seqlens.
Therefore, we believe that mamba, as a competitor of transformer, can improve its hardware utilization during training on real world datasets (the length distribution varies much between data samples) through this PR!
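The shapes in this example can be reproduced with a small sketch. The seven sequence lengths below are an assumption, chosen only so that they are consistent with the shapes above (7 sequences, longest 10, 32 tokens in total). The helper name `pack_padded_batch` is hypothetical, and plain lists stand in for tensors.

```python
def pack_padded_batch(padded, seq_lens):
    """Hypothetical helper: drop padding and concatenate the rows into one sequence."""
    packed, cu_seqlens = [], [0]
    for row, n in zip(padded, seq_lens):
        packed.extend(row[:n])                 # keep only the real tokens
        cu_seqlens.append(cu_seqlens[-1] + n)  # record where the next sequence starts
    return packed, cu_seqlens


PAD = 0
seq_lens = [5, 10, 3, 1, 2, 5, 6]  # assumed lengths: 7 sequences, max 10, sum 32
max_len = max(seq_lens)
# Each row holds token ids 1..n followed by padding up to max_len.
padded = [list(range(1, n + 1)) + [PAD] * (max_len - n) for n in seq_lens]

packed, cu_seqlens = pack_padded_batch(padded, seq_lens)
padded_tokens = len(padded) * max_len
print(padded_tokens, len(packed))  # 70 32
print(cu_seqlens)                  # [0, 5, 15, 18, 19, 21, 26, 32]
```

With these lengths, the padded batch processes 70 token positions while the packed one processes 32, so more than half of the padded computation is spent on padding.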
Thank you for your answer. This is a great code that I will try to use for my project!
Sorry to bother you again, I would like to implement the same operation for bidirectional mamba. I would like to know if I also need to reset the value for cu_seqlens when flipping the propagation sequence to cope with the flipped sequence, and can these two share d_conv?
for example:
out_f, scan_intermediates_f, out_z_f = selective_scan_cuda.fwd(
    conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seqlens
)
out_b, scan_intermediates_b, out_z_b = selective_scan_cuda.fwd(
    conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]),
    delta_bias, delta_softplus, cu_seqlens
)  # cu_seqlens should be changed??
ctx.save_for_backward(
    xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
    conv1d_out, delta, A, A_b, B, C, D, delta_bias,
    scan_intermediates_f, scan_intermediates_b, out_f, out_b, cu_seqlens, d_conv
)  # the same d_conv?
I think I should divide conv1d_out, delta, etc. into subsequences and reverse each subsequence separately? (Instead of the entire sequence, use the same cu_seqlens?)
I copied some methods into MixerModel to help use this feature.
import torch
import torch.nn.functional as F
from einops import rearrange

def unpad_input(self, hidden_states, attention_mask):
    # Flatten batch and sequence dimensions: [B, S, ...] -> [B*S, ...]
    hidden_states = rearrange(hidden_states, "b s ... -> (b s) ...")
    # Valid positions are marked with 1 here; adjust if your mask uses another convention
    valid_mask = attention_mask.squeeze(1).squeeze(1).eq(1)
    seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Cumulative sequence lengths with a leading zero
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # Keep only the valid tokens and add a batch dimension of size 1
    hidden_states = hidden_states[indices].unsqueeze(0)
    return hidden_states, indices, cu_seqlens, max_seqlen_in_batch

def pad_input(self, hidden_states, indices, batch, seqlen):
    """
    :param hidden_states: Shape is [L,H] not [B,L,H]
    :param indices: the indices returned by unpad_input
    :param batch: batch size of the padded output
    :param seqlen: the max_seqlen_in_batch returned by unpad_input
    :return: padded hidden states with shape [B, S, ...]
    """
    output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device, dtype=hidden_states.dtype)
    output[indices] = hidden_states
    return rearrange(output, '(b s) ... -> b s ...', b=batch)
For bidirectional mamba, you need to pass reverse_cu_seqlens to the reverse pass, like this:
out_rev = self.mamba_rev(
    hidden_states.flip(dims=(1,)),  # Flip along the sequence length dimension
    cu_seqlens=reverse_cu_seqlens,  # Reverse cu_seqlens
    inference_params=inference_params
).flip(dims=(1,))  # Flip back for combining with forward hidden states
For example, if you have cu_seqlens = torch.tensor([0, 5, 15, 18, 19, 21, 26, 32]), then reverse_cu_seqlens should be tensor([ 0, 6, 11, 13, 14, 17, 27, 32]), which gives the positions in the reverse pass where hidden_states must be reset.
We can calculate reverse_cu_seqlens with the following formula (new_zeros keeps the leading zero on the same device as cu_seqlens):
reverse_cu_seqlens = torch.cumsum(torch.cat((cu_seqlens.new_zeros(1), (cu_seqlens[1:] - cu_seqlens[:-1]).flip(dims=(0,))), dim=0), dim=0)
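The formula can be checked against the numbers in the example with a small sketch. Plain Python lists are used here instead of tensors so the arithmetic is easy to follow, and the function name `reverse_cu_seqlens_list` is hypothetical.

```python
from itertools import accumulate


def reverse_cu_seqlens_list(cu_seqlens):
    """Boundaries of the packed sequence after it is flipped along the length axis."""
    # Recover the individual sequence lengths from the cumulative offsets.
    seq_lens = [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]
    # After flipping, the sequences appear in reverse order, so accumulate the reversed lengths.
    return [0] + list(accumulate(reversed(seq_lens)))


cu_seqlens = [0, 5, 15, 18, 19, 21, 26, 32]
print(reverse_cu_seqlens_list(cu_seqlens))
# [0, 6, 11, 13, 14, 17, 27, 32]
```

An equivalent way to read the result is that each reversed boundary equals the total length minus an original boundary, taken in reverse order (for instance 32 - 26 = 6 and 32 - 21 = 11).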
I think you might not need to divide these items into subsequences. All you need to do is pass reverse_cu_seqlens to the reverse pass, and you then get the benefits of both bidirectional and variable-length training.
For combining the benefits of bidirectional mamba and this PR's variable-length sequences, I drew my graphical understanding here,
The mechanism can be simply viewed as that when scanning bidirectionally, hidden_states need to be reset on sequence boundaries of both directions.
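The reset mechanism can be illustrated with a toy model. The sketch below is not the selective scan kernel: it assumes a scalar recurrence h_t = decay * h_{t-1} + x_t with a fixed decay, and all function names are hypothetical. It only shows that flipping the whole packed sequence and resetting at the reversed boundaries gives the same result as reversing and scanning each sequence separately.

```python
from itertools import accumulate


def scan_with_resets(x, cu_seqlens, decay=0.5):
    """Toy recurrence h_t = decay * h_{t-1} + x_t, with h reset to 0 at every sequence start."""
    starts = set(cu_seqlens[:-1])
    out, h = [], 0.0
    for t, x_t in enumerate(x):
        if t in starts:
            h = 0.0  # a new sequence begins: do not carry state across the boundary
        h = decay * h + x_t
        out.append(h)
    return out


def reverse_boundaries(cu_seqlens):
    lens = [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]
    return [0] + list(accumulate(reversed(lens)))


cu_seqlens = [0, 2, 5, 6]
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# Backward direction on the packed input: flip, scan with reversed boundaries, flip back.
packed_bwd = scan_with_resets(x[::-1], reverse_boundaries(cu_seqlens))[::-1]

# Reference: reverse and scan every sequence on its own, then concatenate.
reference = []
for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
    seq = x[start:end]
    reference.extend(scan_with_resets(seq[::-1], [0, len(seq)])[::-1])

assert packed_bwd == reference
print(packed_bwd)  # [2.0, 2.0, 6.25, 6.5, 5.0, 6.0]
```

The forward direction works the same way with the original cu_seqlens, so no per-sequence splitting is needed in either direction.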
Thank you so much, I also have another strange error.
Thank you so much! I will immediately try your excellent code in my project.
Hi @zigzagcai, thank you for the very helpful code!
I've been playing around with it but struggling to get generation to work properly. Namely, I'm packing sequences to (1, sum(seq_len), d) for the first step, but how would I go about reshaping the inference_params cache for predicting new tokens?
As far as I can see allocate_inference_cache in the Mamba class doesn't use max_seqlen so there doesn't seem to be a trivial way to do this.
Hope this was clear. Thanks in advance.
With this implementation, I'm curious how this repo can be used with the class MixerModel?
Sorry for my late response.
I have added support for cu_seqlens in the class MixerModel, to enable training over variable-length sequences.
If you want to use this PR with the class MixerModel, all you need to do is modify the collate_fn of your dataloader: pack the original batched variable-length input_ids into one packed input_ids (i.e. batch_size=1), and return cu_seqlens (which records the start index of each sub-sequence in the packed sequence) from your collate_fn.
In this way, variable-length training will work with the forward/backward pass of MixerModel. As a benefit, it avoids wasting computation on padding tokens and improves training efficiency considerably.
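The collate_fn change described above could look roughly like the sketch below. It is only an illustration: the function name `packed_collate_fn` and the dictionary keys are assumptions, and plain lists are used where a real dataloader would return tensors (for example torch.tensor(..., dtype=torch.int32) for cu_seqlens).

```python
def packed_collate_fn(samples):
    """Hypothetical collate_fn: pack variable-length token-id lists into one sequence.

    samples: list of token-id lists with different lengths.
    """
    input_ids, cu_seqlens = [], [0]
    for ids in samples:
        input_ids.extend(ids)
        # The running total is the start index of the next sub-sequence.
        cu_seqlens.append(cu_seqlens[-1] + len(ids))
    # Wrap input_ids in an outer list so the batch dimension has size 1.
    return {"input_ids": [input_ids], "cu_seqlens": cu_seqlens}


batch = packed_collate_fn([[11, 12, 13], [21], [31, 32]])
print(batch["input_ids"])   # [[11, 12, 13, 21, 31, 32]]
print(batch["cu_seqlens"])  # [0, 3, 4, 6]
```

The training loop would then forward both entries to the model, so no padding tokens are ever created.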
I have also added a test to show the numerical equivalence of the mamba block with and without cu_seqlens. Hope it helps!
import random

import torch

from mamba_ssm.modules.mamba_simple import Mamba


def unpack(packed_hidden_states, cu_seqlens):
    """Convert packed_hidden_states (batch_size=1) to zero-padded hidden_states."""
    batch_size = cu_seqlens.shape[0] - 1
    # .item() turns the 0-dim tensor into a plain int for torch.zeros
    seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    hidden_dim = packed_hidden_states.shape[2]
    hidden_states = torch.zeros(
        batch_size, seq_len, hidden_dim,
        dtype=packed_hidden_states.dtype, device=packed_hidden_states.device,
    )
    for i in range(batch_size):
        hidden_states[i, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[:, cu_seqlens[i] : cu_seqlens[i + 1], :]
    return hidden_states


def pack(hidden_states, cu_seqlens):
    """Convert padded hidden_states to packed_hidden_states (without the batch dimension)."""
    batch_size, seq_len, hidden_dim = hidden_states.shape
    seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1]
    seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2)
    indices_3d = (
        torch.arange(seq_len, device=hidden_states.device)
        .unsqueeze(0)
        .unsqueeze(2)
        .repeat(batch_size, 1, hidden_dim)
    )
    # True for real tokens, False for padding
    mask_3d = indices_3d < seq_len_list_3d
    packed_hidden_states = hidden_states[mask_3d].view(-1, hidden_dim)
    return packed_hidden_states


def generate_random_cu_seqlens(seq_len, batch_size):
    """Generate random cu_seqlens for testing."""
    if batch_size > 1:
        ret = sorted(random.sample(range(1, seq_len), batch_size - 1))
    else:
        ret = []
    cu_seqlens = [0] + ret + [seq_len]
    assert batch_size == len(cu_seqlens) - 1
    return cu_seqlens


def main():
    # config tested with A100
    hidden_dim = 2048
    seq_len = 1024
    batch_size = 8
    device = 'cuda'

    itype = torch.float32
    rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    rtolw, atolw = (1e-3, 1e-3)
    # If we have z, the errors on the weights seem higher
    rtolw = max(rtolw, rtol)
    atolw = max(atolw, atol)

    # Generate random cu_seqlens for testing
    cu_seqlens = generate_random_cu_seqlens(seq_len, batch_size)
    cu_seqlens = torch.tensor(cu_seqlens, device=device)
    print(f'Generate random cu_seqlens = {cu_seqlens.tolist()}')

    # Generate packed_hidden_states with random values for testing
    # packed_hidden_states (batch_size=1) should be forwarded with cu_seqlens
    hidden_states_list = [torch.randn(l, hidden_dim, device=device) for l in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()]
    packed_hidden_states = torch.cat(hidden_states_list, dim=0).unsqueeze(0)
    # hidden_states should be forwarded without cu_seqlens
    hidden_states = unpack(packed_hidden_states, cu_seqlens)

    # Check: sum of seq_len of item in hidden_states_list should be equal to seq_len of packed_hidden_states
    assert sum([hs.shape[0] for hs in hidden_states_list]) == packed_hidden_states.shape[1]
    # Check: max of seq_len of item in hidden_states_list should be equal to seq_len of hidden_states
    assert max([hs.shape[0] for hs in hidden_states_list]) == hidden_states.shape[1]

    # create one simple mamba block
    mamba = Mamba(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=hidden_dim,  # Model dimension d_model
        d_state=16,  # SSM state expansion factor
        d_conv=4,  # Local convolution width
        expand=2,  # Block expansion factor
    ).to(device)

    # reference output for forwarding hidden_states
    out_ref = mamba(hidden_states)
    out_ref = pack(out_ref, cu_seqlens).unsqueeze(0)
    # output for forwarding packed_hidden_states with cu_seqlens
    out = mamba(packed_hidden_states, cu_seqlens)

    # Testing the max/mean diff
    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)


if __name__ == "__main__":
    main()
Sorry for my late response.
This PR currently only supports variable-length sequences in training scenarios, where we have verified numerical equivalence and training efficiency on real-world datasets.
Variable-length sequences in inference scenarios are not supported yet, but we will try to make them work when we have bandwidth.
@zigzagcai Thanks for your excellent acceleration work! However, I encountered an error when I tried to run the test code. Do you know what happened and how to solve this problem?
(cu_mamba) andssy@XX:~$ python test_official.py
Generate random cu_seqlens = [0, 43, 129, 286, 508, 779, 949, 987, 1024]
Traceback (most recent call last):
File "test_official.py", line 107, in <module>
main()
File "test_official.py", line 94, in main
out_ref = mamba(hidden_states)
File "/miniconda3/envs/cu_mamba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/miniconda3/envs/cu_mamba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/cu_mamba/mamba_ssm/modules/mamba_simple.py", line 147, in forward
out = mamba_inner_fn(
File "/cu_mamba/mamba_ssm/ops/selective_scan_interface.py", line 365, in mamba_inner_fn
return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
File "/miniconda3/envs/cu_mamba/lib/python3.8/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/miniconda3/envs/cu_mamba/lib/python3.8/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
return fwd(*args, **kwargs)
File "/cu_mamba/mamba_ssm/ops/selective_scan_interface.py", line 250, in forward
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
TypeError: fwd(): incompatible function arguments. The following argument types are supported:
1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: torch.Tensor, arg4: torch.Tensor, arg5: Optional[torch.Tensor], arg6: Optional[torch.Tensor], arg7: Optional[torch.Tensor], arg8: bool) -> List[torch.Tensor]
Invoked with: tensor([[[ 0.0886, -0.0034, 0.0710, ..., 0.0726, 0.0726, 0.0726],
......
The version of both causal-conv1d and mamba is 1.2.0.post1.
Through this implementation, I'm curious how can these repo used in class
MixerModel?Sorry for my late response.
I have added the support for
cu_seqlensin classMixerModel, to enable the capabilities for training over variable-length sequences.If you want to use this PR for class
MixerModel, all you need to do is some code modification in the collate_fn of your dataloader. That is, packing the originalbatched variable-length input_idsintoone packed input_ids, a.k.abatch_size=1, and providescu_seqlens(which records start indexes of each sub-sequence in the packed sequence) in your collate_fn return.In this way, variable length training will work with the forward/backward pass of
MixerModel. As a benefit, it can avoid wasting computation resources on padding tokens and improve training efficiency a lot.I have also added a test to prove the mathematical quivalence with and without
cu_seqlensfor the mamba block. Hope it helps!import random import torch from mamba_ssm.modules.mamba_simple import Mamba ''' unpack function: convert packed_hidden_states (batch_size=1) to hidden_states ''' def unpack(packed_hidden_states, cu_seqlens): batch_size = cu_seqlens.shape[0] - 1 seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() hidden_dim = packed_hidden_states.shape[2] hidden_states = torch.zeros(batch_size, seq_len, hidden_dim, dtype=packed_hidden_states.dtype, device=packed_hidden_states.device) for i in range(batch_size): hidden_states[i, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[:, cu_seqlens[i] : cu_seqlens[i + 1], :] return hidden_states ''' pack function: convert hidden_states to packed_hidden_states (batch_size=1) ''' def pack(hidden_states, cu_seqlens): batch_size, seq_len, hidden_dim = hidden_states.shape seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1] seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2) indices_3d = ( torch.arange(seq_len, device=hidden_states.device) .unsqueeze(0) .unsqueeze(2) .repeat(batch_size, 1, hidden_dim) ) mask_3d = indices_3d < seq_len_list_3d packed_hidden_states = hidden_states[mask_3d].view(-1, hidden_dim) return packed_hidden_states ''' Generate random cu_seqlens for testing ''' def generate_random_cu_seqlens(seq_len, batch_size): if batch_size > 1: ret = sorted(random.sample(range(1, seq_len), batch_size - 1)) else: ret = [] cu_seqlens = [0] + ret + [seq_len] assert batch_size == len(cu_seqlens) - 1 return cu_seqlens def main(): # config tested with A100 hidden_dim = 2048 seq_len = 1024 batch_size = 8 device='cuda' itype = torch.float32 rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 rtolw, atolw = (1e-3, 1e-3) # If we have z, the errors on the weights seem higher rtolw = max(rtolw, rtol) atolw = max(atolw, atol) # Generate random cu_seqlens for testing cu_seqlens = generate_random_cu_seqlens(seq_len, batch_size) 
cu_seqlens = torch.tensor(cu_seqlens, device=device) print(f'Generate random cu_seqlens = {cu_seqlens.tolist()}') # Generate packed_hidden_states with random values for testing # packed_hidden_states (batch_size=1) should be forwarded with cu_seqlens hidden_states_list = [torch.randn(l, hidden_dim, device=device) for l in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()] packed_hidden_states = torch.cat(hidden_states_list, dim=0).unsqueeze(0) # hidden_states should be forwarded without cu_seqlens hidden_states = unpack(packed_hidden_states, cu_seqlens) # Check: sum of seq_len of item in hidden_states_list should be equal to seq_len of packed_hidden_states assert sum([hs.shape[0] for hs in hidden_states_list]) == packed_hidden_states.shape[1] # Check: max of seq_len of item in hidden_states_list should be equal to seq_len of hidden_states assert max([hs.shape[0] for hs in hidden_states_list]) == hidden_states.shape[1] # creat one simple mamba block mamba = Mamba( # This module uses roughly 3 * expand * d_model^2 parameters d_model=hidden_dim, # Model dimension d_model d_state=16, # SSM state expansion factor d_conv=4, # Local convolution width expand=2, # Block expansion factor ).to(device) # reference output for forwardding hidden_states out_ref = mamba(hidden_states) out_ref = pack(out_ref, cu_seqlens).unsqueeze(0) # output for forwardding packed_hidden_states with cu_seqlens out = mamba(packed_hidden_states, cu_seqlens) # Testing the max/mean diff print(f'Output max diff: {(out - out_ref).abs().max().item()}') print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) if __name__ == "__main__": main()
@zigzagcai Thanks for your excellent acceleration work! But it encountered an error when I tried to run the test code. Do you know what happened and how to solve this problem?
(cu_mamba) andssy@XX:~$ python test_official.py Generate random cu_seqlens = [0, 43, 129, 286, 508, 779, 949, 987, 1024] Traceback (most recent call last): File "test_official.py", line 107, in <module> main() File "test_official.py", line 94, in main out_ref = mamba(hidden_states) File "/miniconda3/envs/cu_mamba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/miniconda3/envs/cu_mamba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) File "/cu_mamba/mamba_ssm/modules/mamba_simple.py", line 147, in forward out = mamba_inner_fn( File "/cu_mamba/mamba_ssm/ops/selective_scan_interface.py", line 365, in mamba_inner_fn return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, File "/miniconda3/envs/cu_mamba/lib/python3.8/site-packages/torch/autograd/function.py", line 539, in apply return super().apply(*args, **kwargs) # type: ignore[misc] File "/miniconda3/envs/cu_mamba/lib/python3.8/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd return fwd(*args, **kwargs) File "/cu_mamba/mamba_ssm/ops/selective_scan_interface.py", line 250, in forward out, scan_intermediates, out_z = selective_scan_cuda.fwd( TypeError: fwd(): incompatible function arguments. The following argument types are supported: 1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: torch.Tensor, arg4: torch.Tensor, arg5: Optional[torch.Tensor], arg6: Optional[torch.Tensor], arg7: Optional[torch.Tensor], arg8: bool) -> List[torch.Tensor] Invoked with: tensor([[[ 0.0886, -0.0034, 0.0710, ..., 0.0726, 0.0726, 0.0726], ......the version of causal-conv1d and mamba is
1.2.0.post1Through this implementation, I'm curious how can these repo used in class
MixerModel?
Sorry for my late response. I have added support for cu_seqlens in class MixerModel, to enable training over variable-length sequences. If you want to use this PR with class MixerModel, all you need to do is modify the collate_fn of your dataloader: pack the original batched variable-length input_ids into one packed input_ids (a.k.a. batch_size=1), and return cu_seqlens (which records the start index of each sub-sequence in the packed sequence) from your collate_fn. In this way, variable-length training will work with the forward/backward pass of MixerModel. As a benefit, it avoids wasting computation on padding tokens and improves training efficiency a lot.
I have also added a test to prove the mathematical equivalence with and without cu_seqlens for the mamba block. Hope it helps!
import random
import torch
from mamba_ssm.modules.mamba_simple import Mamba

# unpack function: convert packed_hidden_states (batch_size=1) to hidden_states
def unpack(packed_hidden_states, cu_seqlens):
    batch_size = cu_seqlens.shape[0] - 1
    seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
    hidden_dim = packed_hidden_states.shape[2]
    hidden_states = torch.zeros(batch_size, seq_len, hidden_dim, dtype=packed_hidden_states.dtype, device=packed_hidden_states.device)
    for i in range(batch_size):
        hidden_states[i, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[:, cu_seqlens[i] : cu_seqlens[i + 1], :]
    return hidden_states

# pack function: convert hidden_states to packed_hidden_states (batch_size=1)
def pack(hidden_states, cu_seqlens):
    batch_size, seq_len, hidden_dim = hidden_states.shape
    seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1]
    seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2)
    indices_3d = (
        torch.arange(seq_len, device=hidden_states.device)
        .unsqueeze(0)
        .unsqueeze(2)
        .repeat(batch_size, 1, hidden_dim)
    )
    mask_3d = indices_3d < seq_len_list_3d
    packed_hidden_states = hidden_states[mask_3d].view(-1, hidden_dim)
    return packed_hidden_states

# Generate random cu_seqlens for testing
def generate_random_cu_seqlens(seq_len, batch_size):
    if batch_size > 1:
        ret = sorted(random.sample(range(1, seq_len), batch_size - 1))
    else:
        ret = []
    cu_seqlens = [0] + ret + [seq_len]
    assert batch_size == len(cu_seqlens) - 1
    return cu_seqlens

def main():
    # config tested with A100
    hidden_dim = 2048
    seq_len = 1024
    batch_size = 8
    device = 'cuda'

    itype = torch.float32
    rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    rtolw, atolw = (1e-3, 1e-3)
    # If we have z, the errors on the weights seem higher
    rtolw = max(rtolw, rtol)
    atolw = max(atolw, atol)

    # Generate random cu_seqlens for testing
    cu_seqlens = generate_random_cu_seqlens(seq_len, batch_size)
    cu_seqlens = torch.tensor(cu_seqlens, device=device)
    print(f'Generate random cu_seqlens = {cu_seqlens.tolist()}')

    # Generate packed_hidden_states with random values for testing
    # packed_hidden_states (batch_size=1) should be forwarded with cu_seqlens
    hidden_states_list = [torch.randn(l, hidden_dim, device=device) for l in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()]
    packed_hidden_states = torch.cat(hidden_states_list, dim=0).unsqueeze(0)
    # hidden_states should be forwarded without cu_seqlens
    hidden_states = unpack(packed_hidden_states, cu_seqlens)

    # Check: the sequence lengths in hidden_states_list should sum to the seq_len of packed_hidden_states
    assert sum([hs.shape[0] for hs in hidden_states_list]) == packed_hidden_states.shape[1]
    # Check: the longest sequence in hidden_states_list should equal the seq_len of hidden_states
    assert max([hs.shape[0] for hs in hidden_states_list]) == hidden_states.shape[1]

    # create one simple mamba block
    mamba = Mamba(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=hidden_dim,  # Model dimension d_model
        d_state=16,  # SSM state expansion factor
        d_conv=4,  # Local convolution width
        expand=2,  # Block expansion factor
    ).to(device)

    # reference output for forwarding hidden_states
    out_ref = mamba(hidden_states)
    out_ref = pack(out_ref, cu_seqlens).unsqueeze(0)
    # output for forwarding packed_hidden_states with cu_seqlens
    out = mamba(packed_hidden_states, cu_seqlens)

    # Testing the max/mean diff
    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

if __name__ == "__main__":
    main()
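The collate_fn change described above can be sketched in a few lines. This is only an illustration of the packing idea, not code from the PR: the function name `packed_collate_fn` and the input format (a list of plain Python lists of token ids) are assumptions, and plain lists are used instead of tensors so the snippet runs without a GPU.

```python
def packed_collate_fn(batch):
    """Pack variable-length token-id lists into a single sequence.

    `batch` is assumed to be a list of Python lists of token ids
    (a hypothetical dataset format, not taken from the PR).
    """
    input_ids = []
    cu_seqlens = [0]  # cumulative lengths: start offset of each sub-sequence
    for seq in batch:
        input_ids.extend(seq)
        cu_seqlens.append(cu_seqlens[-1] + len(seq))
    # The outer list plays the role of the batch dimension (batch_size=1).
    return {"input_ids": [input_ids], "cu_seqlens": cu_seqlens}


batch = [[11, 12, 13], [21], [31, 32]]
packed = packed_collate_fn(batch)
print(packed["input_ids"])   # [[11, 12, 13, 21, 31, 32]]
print(packed["cu_seqlens"])  # [0, 3, 4, 6]
```

In a real dataloader, both fields would be converted to tensors (for example with torch.tensor) before being passed to the model, as the test script above does for cu_seqlens.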
Hello, thank you for your interest in this PR!
Have you recompiled and installed the mamba package? Since this PR modifies the selective scan CUDA kernel, the mamba package needs to be recompiled and installed from source, like this:
pip install -e .
You can also just try my branch here, which has been synced with the latest main branch:
git clone https://github.com/zigzagcai/mamba.git --branch feat/add-cu_seqlens
Also, my causal-conv1d version is causal-conv1d==1.2.0.post2
When your environment is ready and you run the test code, the output should look like this:
Generate random cu_seqlens = [0, 49, 138, 206, 224, 360, 438, 993, 1024]
Output max diff: 6.631016731262207e-07
Output mean diff: 3.7819994247456634e-08
The differences are at the level of float32 rounding error, which indicates that the outputs with and without cu_seqlens are equivalent for the mamba block.
Besides, we also provide unit tests to verify the fwd/bwd correctness of our modifications to the selective scan CUDA kernel.
$ pytest tests/
============================= test session starts ==============================
platform linux -- Python 3.10.14, pytest-8.1.1, pluggy-1.4.0
rootdir: /home/zigzagcai/mamba
plugins: typeguard-3.0.2
collected 86 items
tests/ops/test_selective_scan.py .................... [ 23%]
tests/ops/test_selective_scan_var_len.py ............ [ 37%]
tests/ops/triton/test_selective_state_update.py ........................ [ 65%]
.............................. [100%]
============================= 86 passed in 32.31s ==============================
@zigzagcai Thanks! I ran it successfully. With this work, I am able to increase the batch size without worrying about GPU OOM. One last question: how should I cite your excellent work?
Very glad to see that it is helpful to you! :D
It's fine to cite the URL of this PR or cite my forked repo for simplicity.
Also, if this feature gets merged by the maintainers, you can just cite the official Mamba repo.
@zigzagcai Hi, thanks for your contribution!
I noticed that the unpack function above copies from the original packed hidden states into the new tensor one batch item at a time, which is quite slow when the batch size is large (the effect is quite noticeable in my current project). So I modified the unpack function to use indexing instead of copying. It works by aligning the starting index of each batch item with cu_seqlens.
def unpack(packed_hidden_states, cu_seqlens):
    batch_size = cu_seqlens.shape[0] - 1
    seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
    packed_hidden_states = packed_hidden_states.squeeze(0)
    ori_indices = (
        torch.arange(seq_len, device=cu_seqlens.device)
        .unsqueeze(0)
        .expand((batch_size, seq_len))
    )
    ori_indices = (ori_indices + cu_seqlens[:-1].unsqueeze(1)) % (
        len(packed_hidden_states)
    )
    return packed_hidden_states[ori_indices]
As for performance, I also use the same technique to implement a bidirectional mamba without using cu_seqlens, so there shouldn't be any slowdown even if you unpack and repack the hidden states at every layer. The only downside is that, instead of being padded with 0, the new hidden states are padded with whatever tokens the wrapped-around indices point to (i.e. tokens from other positions in the packed sequence). But that's fine, since no one computes the loss on the padding anyway.
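To make the index arithmetic of the vectorized unpack concrete, here is a small pure-Python sketch of the same computation on plain lists. The helper name `gather_indices` and the extra validity mask are illustrative additions, not part of the snippet above; the mask simply marks which gathered slots hold real tokens of the row's own sequence.

```python
def gather_indices(cu_seqlens):
    """Mirror the index computation of the vectorized unpack on plain lists.

    Returns (indices, valid): indices[i][j] is the position in the packed
    sequence that row i, column j reads from; valid[i][j] tells whether that
    slot belongs to sequence i or is padding.
    """
    starts = cu_seqlens[:-1]
    lengths = [end - start for start, end in zip(starts, cu_seqlens[1:])]
    max_len = max(lengths)
    total = cu_seqlens[-1]  # total number of packed tokens
    # Same formula as above: (arange(max_len) + start) % total
    indices = [[(start + j) % total for j in range(max_len)] for start in starts]
    valid = [[j < n for j in range(max_len)] for n in lengths]
    return indices, valid


indices, valid = gather_indices([0, 3, 4, 6])
for row, mask in zip(indices, valid):
    print(row, mask)
# [0, 1, 2] [True, True, True]
# [3, 4, 5] [True, False, False]
# [4, 5, 0] [True, True, False]
```

The padding slots (False in the mask) read tokens that belong to the following sequences, or wrap around to the start of the packed sequence. If zero padding is needed, the mask can be used to zero those slots after the gather.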
Hi @Museum7432, thanks for sharing. I think your unpack function is better!
In our project, the data packing operation is done in the dataloader, which is a separate process from the training process, so it has almost no effect on training efficiency in our experiments.
But for other scenarios, if users want better efficiency, your unpack function is the better choice!
Hi, first off, great work. I have not tested it out myself yet; I am browsing through the code and working out a mental model. Since the ssm_op computes (ab1.x * ab0.x, ab1.x * ab0.y + ab1.y) (source), how does setting ab1.x to zero not have an adverse effect on any of the following values? I do not have CUDA programming experience, so please take this with a pinch of salt; my mental model from studying the code might be incorrect.
In the binary operator, x is accumulating the "exponent", if you think of an exponential moving average calculation. Setting one of them to zero now means the rest of the exponents will also be zero, correct?
My assumption is that the CUB scan operation relies on spreading the data across threads, then the code does an additional split across blocks, and the results of the blocks are handled using the callback. I am also assuming that the inclusive scan does not guarantee order of execution because of the spread across threads (I might be wrong), but even if it did, the following example is still relevant:
Say the input is the following 8 tuples.
[(2,3),(4,1),(1,5),(3,2),(6,7),(5,9),(8,0),(9,6)]
Assume two blocks of 4 threads each, and that we are resetting after the 2nd thread.
Is not the resetting along with the operator going to cause a data corruption?
[(2,3),(8,5),(0,5),(0,5),(0,5),(0,5),(0,0),(0,0)]
I do not see any additional handling of these. To get the correct result, a different operator would have to be used at the point of reset. For example, assuming the data is processed sequentially within a block:
Without reset:
(a, b) -> (b.x*a.x, b.x*a.y + b.y)
With reset:
(a, b) -> (1*b.x, 0 + b.y)
This can be packed algebraically into a single operator, but it would not be associative.
I have read that you included tests, which I have not studied in detail yet, but I am wondering if you could shed some light on how purely resetting A to 0 does not cause data corruption?
Hi @PheelaV, thank you for your interest in this PR!
The mathematical proof of mamba varlen fwd/bwd is somewhat complicated. For those who are interested in verifying the correctness of the fwd/bwd pass in the selective scan kernel and the mamba block, you can navigate to tests/ops/test_selective_scan_var_len.py and tests/ops/test_mamba_cu_seqlens_equivalence.py and run the commands below:
python tests/ops/test_mamba_cu_seqlens_equivalence.py
pytest tests/ops/test_selective_scan_var_len.py
Also, I found an awesome and easy-to-understand figure that helps to visualize the basic idea of mamba. From this picture you can see why simply resetting A bar at sequence boundaries works for the varlen mamba proposed in this PR.
Figure cited from https://github.com/jzhang38/LongMamba?tab=readme-ov-file#preliminary-studies
Mamba can be seen as a kind of mathematical transformation: discrete (sub-figure a) -> continuous (sub-figure b) -> continuous (sub-figure c) -> discrete (sub-figure d). We apply Zero-Order Hold (ZOH) to get sub-figure b from sub-figure a. Then, from sub-figure b to sub-figure c, we learn and keep track of the coefficients of Legendre polynomials. Finally, sub-figure d is what selective scan does.
Therefore, resetting A bar at sequence boundaries can be visualized as making delta t at the sequence boundaries infinite: since A bar = exp(delta t * A) and A is negative, A bar goes to 0 as delta t goes to infinity. Seen from the continuous-time domain, sentences do not affect each other when they are packed and processed together, since there is an infinite delta t between sentences to keep them apart.
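The effect of the reset on the scan itself can also be checked with a small pure-Python model. This is a sketch of the idea only, not the CUDA kernel: each scan element (a, b) stands for the affine map h -> a*h + b, the combine function follows the ssm_op formula quoted earlier in the thread, and the helper names are hypothetical. Because composing affine maps is associative for any value of a (including 0), a block-wise scan with a carried prefix gives the same hidden states as a sequential one, and zeroing a at the first token of each sequence reproduces the per-sequence results.

```python
def ssm_op(ab0, ab1):
    """Compose two scan elements; ab0 is the earlier one, ab1 the later one."""
    return (ab1[0] * ab0[0], ab1[0] * ab0[1] + ab1[1])


def inclusive_scan(elems):
    """Sequential inclusive scan with ssm_op."""
    out, acc = [], None
    for e in elems:
        acc = e if acc is None else ssm_op(acc, e)
        out.append(acc)
    return out


def blockwise_scan(elems, block_size):
    """Scan each block on its own, then carry the running prefix across blocks."""
    out, prefix = [], None
    for start in range(0, len(elems), block_size):
        local = inclusive_scan(elems[start:start + block_size])
        if prefix is not None:
            local = [ssm_op(prefix, e) for e in local]
        out.extend(local)
        prefix = local[-1]
    return out


def reset_at_starts(elems, cu_seqlens):
    """Set a = 0 on the first element of every packed sequence."""
    starts = set(cu_seqlens[:-1])
    return [(0, b) if i in starts else (a, b) for i, (a, b) in enumerate(elems)]


def per_sequence_states(elems, cu_seqlens):
    """Reference: scan every sequence separately, starting from h = 0."""
    out = []
    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        out.extend(h for _, h in inclusive_scan(elems[s:e]))
    return out


# The eight tuples from the question, with a boundary after the 2nd element.
elems = [(2, 3), (4, 1), (1, 5), (3, 2), (6, 7), (5, 9), (8, 0), (9, 6)]
cu_seqlens = [0, 2, 8]

packed = [h for _, h in blockwise_scan(reset_at_starts(elems, cu_seqlens), 4)]
print(packed)                                  # [3, 13, 5, 17, 109, 554, 4432, 39894]
print(per_sequence_states(elems, cu_seqlens))  # [3, 13, 5, 17, 109, 554, 4432, 39894]
```

The accumulated x component does become 0 for every element after a boundary, but x is only used to carry the state from before the prefix into later positions, which is exactly what should be cut off. The hidden state is the y component, and it is unaffected within each sequence.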
Some useful links:
- https://blog.premai.io/s4-and-mamba/
- https://srush.github.io/annotated-mamba/hard.html
Update from 2024/07/22:
- I have migrated to tridao's latest implementation of variable-length causal_conv1d (which requires causal-conv1d>=1.4.0) in this commit. It is awesome that all the variable-length features in mamba are now natively powered by CUDA kernels. Much faster!
- The API is now unified with mamba2 and flash-attn (mamba, mamba2, and flash-attn all use cu_seqlens to power variable-length training). Much easier to use!
- The unit test shows that the variable-length mamba block matches the padded version in both the forward pass and the backward pass, with differences at the level of float32 rounding error:
python tests/ops/test_mamba_cu_seqlens_equivalence.py
Generate random cu_seqlens = [0, 116, 155, 349, 479, 674, 864, 881, 1024]
max diff for output in varlen_mamba fwd pass: 4.470348358154297e-08
mean diff for output in varlen_mamba fwd pass: 5.5261386577853955e-09
max diff for A_log in varlen_mamba bwd pass: 6.239861249923706e-08
mean diff for A_log in varlen_mamba bwd pass: 5.321690865756068e-10
max diff for D in varlen_mamba bwd pass: 6.318092346191406e-06
mean diff for D in varlen_mamba bwd pass: 6.176169335958548e-07
max diff for in_proj.weight in varlen_mamba bwd pass: 1.9073486328125e-05
mean diff for in_proj.weight in varlen_mamba bwd pass: 1.098805341825937e-06
max diff for conv1d.weight in varlen_mamba bwd pass: 5.662441253662109e-06
mean diff for conv1d.weight in varlen_mamba bwd pass: 8.699786349097849e-07
max diff for conv1d.bias in varlen_mamba bwd pass: 1.0013580322265625e-05
mean diff for conv1d.bias in varlen_mamba bwd pass: 1.4602501323679462e-06
max diff for x_proj.weight in varlen_mamba bwd pass: 3.6954879760742188e-06
mean diff for x_proj.weight in varlen_mamba bwd pass: 2.984411295869904e-08
max diff for dt_proj.weight in varlen_mamba bwd pass: 8.731149137020111e-09
mean diff for dt_proj.weight in varlen_mamba bwd pass: 3.4094516099258954e-10
max diff for dt_proj.bias in varlen_mamba bwd pass: 2.60770320892334e-08
mean diff for dt_proj.bias in varlen_mamba bwd pass: 2.458180992093162e-09
max diff for out_proj.weight in varlen_mamba bwd pass: 5.7220458984375e-06
mean diff for out_proj.weight in varlen_mamba bwd pass: 2.629302855439164e-07
pytest tests/
============================= test session starts ==============================
platform linux -- Python 3.10.14, pytest-8.3.1, pluggy-1.5.0
rootdir: /dev/varlen_mamba
plugins: typeguard-3.0.2
collected 392 items
tests/ops/test_selective_scan.py .................... [ 5%]
tests/ops/test_selective_scan_var_len.py ............ [ 8%]
tests/ops/triton/test_layernorm_gated.py .s.s.s.s.s.s.....s.s.s.s.s.s... [ 16%]
..s.s.s.s.s.s.....s.s.s.s.s.s.... [ 24%]
tests/ops/triton/test_selective_state_update.py ........................ [ 30%]
........................................................................ [ 48%]
........................................................................ [ 67%]
........................................................................ [ 85%]
.............................. [ 93%]
tests/ops/triton/test_ssd.py ........................ [ 99%]
tests/test_generation.py .. [100%]
================= 368 passed, 24 skipped in 256.74s (0:04:16) ==================
Dear authors,
@tridao @albertfgu First of all, thanks for the awesome work on the theoretical analysis and code development of mamba, mamba2, and the other state space models!
Currently, many users (https://github.com/state-spaces/mamba/issues/356, https://github.com/state-spaces/mamba/issues/236, https://github.com/state-spaces/mamba/issues/180) expect mamba to natively support variable-length training (just like flash-attn and mamba2 do) to use the hardware efficiently, so we tried to develop this feature.
In this PR:
(1) We provide an API unified with mamba2 and flash-attn to support variable-length training (via cu_seqlens).
(2) Variable-length mamba is natively powered by the causal_conv1d and selective scan CUDA kernels.
So, could this PR be reviewed and merged as a feature for mamba, if possible? Thanks!
It's great to see that there is already one paper/project (Is Mamba Compatible with Trajectory Optimization in Offline Reinforcement Learning, NeurIPS'24) adopting our code in the area of offline reinforcement learning.
Link: https://arxiv.org/pdf/2405.12094
Hi @zigzagcai, thank you for the great work! I tried to install your version but ran into the selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol problem. Does it also occur for you when you test the code?
The full pipeline I did is the following:
# (optionally) clone causal-conv1d, also tried pip install causal-conv1d==1.4.0
git clone https://github.com/Dao-AILab/causal-conv1d
cd causal-conv1d
git checkout v1.4.0
pip install -e .
cd ..
# clone and checkout your pr
git clone https://github.com/state-spaces/mamba
cd mamba
git fetch origin pull/244/head:pr-244
git checkout pr-244
pip install -e .
Tried installing with pytorch 2.4, 2.1, cuda 12.5, 12.1. All settings have the same problem:
> python tests/ops/test_mamba_cu_seqlens_equivalence.py
Traceback (most recent call last):
File "/.../mamba/tests/ops/test_mamba_cu_seqlens_equivalence.py", line 5, in <module>
from mamba_ssm.modules.mamba_simple import Mamba
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/__init__.py", line 3, in <module>
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/selective_scan_interface.py", line 16, in <module>
import selective_scan_cuda
ImportError: /usr/local/lib/python3.10/dist-packages/selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops10zeros_like4callERKNS_6TensorESt8optionalIN3c1010ScalarTypeEES5_INS6_6LayoutEES5_INS6_6DeviceEES5_IbES5_INS6_12MemoryFormatEE
Additionally, I also found that the installed causal-conv1d and mamba-ssm don't seem to recognize each other: when I run the following, it shows that causal-conv1d is required by nothing:
>pip show causal-conv1d
Name: causal-conv1d
Version: 1.4.0
Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
Home-page: https://github.com/Dao-AILab/causal-conv1d
Author: Tri Dao
Author-email: [email protected]
License:
Location: /usr/local/lib/python3.10/dist-packages
Requires: ninja, packaging, torch
Required-by: (empty here)
Similarly, mamba_ssm does not require causal-conv1d:
> pip show mamba-ssm
Name: mamba_ssm
Version: 2.2.2
Summary: Mamba state-space model
Home-page:
Author:
Author-email: Tri Dao <[email protected]>, Albert Gu <[email protected]>
...
Location: /usr/local/lib/python3.10/dist-packages
Requires: einops, ninja, packaging, setuptools, torch, transformers, triton (causal-conv1d is not here)
Required-by:
If this issue doesn't occur for you, could you provide the install script you are using for the most up-to-date version? Thanks!
Hi @JindongJiang ,
Here are my minimal reproduction steps.
- The hardware and software info:
HW: A800/A100
Driver: CUDA 11.8
- Steps to set up the environment:
conda create -n mamba_dev python=3.10
conda activate mamba_dev
pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
pip install causal-conv1d==1.4.0
pip install einops huggingface-hub transformers triton pytest
git clone https://github.com/zigzagcai/varlen_mamba.git --branch feat/add-cu_seqlens
cd varlen_mamba
pip install -e .
- Run tests:
pytest tests/
============================= test session starts ==============================
platform linux -- Python 3.10.14, pytest-8.3.2, pluggy-1.5.0
rootdir: /blahblah/zigzagcai/varlen_mamba
plugins: typeguard-3.0.2
collected 392 items
tests/ops/test_selective_scan.py .................... [ 5%]
tests/ops/test_selective_scan_var_len.py ............ [ 8%]
tests/ops/triton/test_layernorm_gated.py .s.s.s.s.s.s.....s.s.s.s.s.s... [ 16%]
..s.s.s.s.s.s.....s.s.s.s.s.s.... [ 24%]
tests/ops/triton/test_selective_state_update.py ........................ [ 30%]
........................................................................ [ 48%]
........................................................................ [ 67%]
........................................................................ [ 85%]
.............................. [ 93%]
tests/ops/triton/test_ssd.py ........................ [ 99%]
tests/test_generation.py .. [100%]
================= 368 passed, 24 skipped in 183.78s (0:03:03) ==================
python tests/ops/test_mamba_cu_seqlens_equivalence.py
Generate random cu_seqlens = [0, 5, 84, 182, 202, 284, 796, 836, 1024]
max diff for output in varlen_mamba fwd pass: 6.407499313354492e-07
mean diff for output in varlen_mamba fwd pass: 3.794611203034037e-08
max diff for A_log in varlen_mamba bwd pass: 6.705522537231445e-08
mean diff for A_log in varlen_mamba bwd pass: 6.687657094772703e-10
max diff for D in varlen_mamba bwd pass: 4.76837158203125e-06
mean diff for D in varlen_mamba bwd pass: 6.003104999763309e-07
max diff for in_proj.weight in varlen_mamba bwd pass: 1.9073486328125e-05
mean diff for in_proj.weight in varlen_mamba bwd pass: 1.0953947366942884e-06
max diff for conv1d.weight in varlen_mamba bwd pass: 5.364418029785156e-06
mean diff for conv1d.weight in varlen_mamba bwd pass: 8.792806056590052e-07
max diff for conv1d.bias in varlen_mamba bwd pass: 7.867813110351562e-06
mean diff for conv1d.bias in varlen_mamba bwd pass: 1.4787228792556562e-06
max diff for x_proj.weight in varlen_mamba bwd pass: 5.029141902923584e-06
mean diff for x_proj.weight in varlen_mamba bwd pass: 3.1919995535645285e-08
max diff for dt_proj.weight in varlen_mamba bwd pass: 1.3300450518727303e-08
mean diff for dt_proj.weight in varlen_mamba bwd pass: 3.616623112101536e-10
max diff for dt_proj.bias in varlen_mamba bwd pass: 3.166496753692627e-08
mean diff for dt_proj.bias in varlen_mamba bwd pass: 2.6783406603669846e-09
max diff for out_proj.weight in varlen_mamba bwd pass: 6.67572021484375e-06
mean diff for out_proj.weight in varlen_mamba bwd pass: 2.693569740586099e-07
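To compare another environment against the working setup listed above (torch 2.1.0, causal-conv1d 1.4.0), one option is to print the installed package versions. The snippet below is a minimal standard-library sketch; the helper name `installed_versions` is hypothetical. As a general note, a mismatch between the PyTorch version an extension was compiled against and the one imported at runtime is a common cause of undefined-symbol errors in compiled extensions, although the thread does not confirm that this is the cause here.

```python
from importlib import metadata


def installed_versions(names):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Distribution names as they appear in the pip commands above.
for name, version in installed_versions(["torch", "causal-conv1d", "mamba-ssm"]).items():
    print(f"{name}: {version if version is not None else 'not installed'}")
```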

