How should I modify the code to implement Sequence Parallel for training Multimodal models?
Hi! I'm using XTuner to finetune multimodal LLMs on long-context data, so I need sequence parallel to avoid OOM. I have read the sequence parallelism code, but my knowledge of these techniques is limited.
I understand that I need to implement dispatched forward functions for the corresponding modules, like xtuner/model/modules/dispatch/qwen2.py, and register these functions in the corresponding DISPATCH_MAPPING in xtuner/model/modules/dispatch/__init__.py (a toy sketch of my understanding follows below). But I can't figure out whether I need to change other parts of the code. Multimodal LLMs introduce an additional stage that scatters image features into the corresponding positions in the input sequence; does this stage interact with sequence parallel (see the second sketch below)?
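To make sure I've understood the dispatch mechanism correctly, here is a minimal, self-contained sketch of the pattern as I read it: a mapping from module types to replacement forwards, applied by rebinding `forward` on matching modules. All names here (`ToyAttention`, `patched_forward`, the `DISPATCH_MAPPING` dict, the `dispatch` helper) are placeholders I made up for illustration, not the exact XTuner API:

```python
import types

import torch
import torch.nn as nn


# Stand-in for a sequence-parallel-aware forward; in the real code this
# would shard/gather along the sequence dimension around attention.
def patched_forward(self, x):
    return self.linear(x) * 2  # tag the output so the dispatch is observable


class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


# Hypothetical mapping from module class name to replacement forward,
# mirroring how I understand DISPATCH_MAPPING to work.
DISPATCH_MAPPING = {"ToyAttention": patched_forward}


def dispatch(model: nn.Module):
    # Walk the model and rebind `forward` on every module whose class
    # name appears in the mapping.
    for module in model.modules():
        fn = DISPATCH_MAPPING.get(type(module).__name__)
        if fn is not None:
            module.forward = types.MethodType(fn, module)


model = ToyAttention()
dispatch(model)
print(model(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```

Is this roughly the right mental model, or does XTuner's dispatch do more than swap the forward functions?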
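And here is a single-process toy illustrating the part I'm unsure about: the multimodal stage scatters visual tokens into the embedded text sequence, and sequence parallel then (as I understand it) splits the sequence dimension across ranks. All shapes and names below are made up; the rank split is simulated with `torch.chunk` instead of real distributed communication:

```python
import torch

seq_len, hidden, sp_world_size = 8, 4, 2

text_embeds = torch.zeros(1, seq_len, hidden)  # embedded text tokens
image_feats = torch.ones(3, hidden)            # 3 visual tokens
image_pos = torch.tensor([2, 3, 4])            # <image> placeholder positions

# The extra stage multimodal models add: scatter visual tokens into place.
text_embeds[0, image_pos] = image_feats

# What I believe sequence parallel then does: each rank keeps one chunk
# of the sequence dimension (simulated here on one process).
shards = torch.chunk(text_embeds, sp_world_size, dim=1)
for rank, shard in enumerate(shards):
    print(rank, shard.shape)  # each rank sees (1, seq_len // sp_world_size, hidden)
```

Concretely: if the scatter happens on the full sequence *before* the split, I'd guess only the dispatched attention forwards need changing; but if the sequence is split first, the image positions would land on different ranks and need remapping. Which order does XTuner use, and do I need to handle this myself?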
I would appreciate it if you could offer some help! Thank you!