[REQUEST] Switch to a unified process group management framework for DP/TP/PP/EP/SP

Open eternalNight opened this issue 2 months ago • 6 comments

Problem Statement

Today we have multiple sets of ProcessGroup management in the codebase for different parallel scenarios, namely:

  • PipelineModule uses PipelineParallelGrid for pipeline parallelism.
  • AutoTP uses one set of APIs (whose names start with letters) from utils.groups.
  • MoE uses another set of APIs (whose names start with _) from utils.groups.
  • Ulysses SP uses a DeviceMesh created on deepspeed.initialize.

All of those implementations handle process group creation (using either DeviceMesh or custom logic) under complex combinations of parallelism. This significantly increases the maintenance burden and makes it harder to introduce new parallelism techniques.

Besides, that divergence also confuses users. deepspeed.initialize() accepts two topology-related parameters, mpu and mesh_param: mpu is for DP/TP (without PP/EP/SP) and mesh_param is for SP, while for PP/EP the parallelism topology must be configured at model initialization. That does not provide a good experience for users who want to try out different configurations for best efficiency.

Proposal

We would like to refactor those ProcessGroup management facilities so that:

  • A unified DeviceMesh based module serves all parallelism strategies and their combinations.
  • Let DeviceMesh instances create and manage ProcessGroups when possible.
  • Simplify the mesh topology configuration interface of deepspeed.initialize.
  • (Advanced) Allow different models in the same world to use different parallelism strategies. This is mainly for RL post-training, but needs a more careful feasibility investigation due to the global map design of ProcessGroup. See https://github.com/volcengine/verl/discussions/897 MCore pain point 1 for more information.
  • (Advanced) Make it easy to extend ZeRO to support HSDP-style replicate + shard parallelism for asymmetric GPU clusters. The extension itself will be tracked in a separate issue.

While it supports multiple dimensions, DeviceMesh does not fit all parallelism techniques and thus needs extension. Essentially, for each dimension, DeviceMesh creates a ProcessGroup among the ranks that share the same coordinates in every other dimension. That fits the needs of DP, TP and SP, but PP and EP have additional requirements: P2P groups among adjacent stages for PP, and global data parallel groups for EP. We would like to subclass DeviceMesh in DeepSpeed (e.g. DeepSpeedDeviceMesh?) to collect parallelism settings from both the model and the configuration and create those additional groups. Each model has its own DeepSpeedDeviceMesh instance at DeepSpeedEngine.mesh_device and fetches ProcessGroups from there.
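
To make the grouping rule concrete, here is a minimal pure-Python sketch (no torch required) that enumerates the rank groups a row-major mesh yields along one dimension, plus the adjacent-stage pairs that PP needs on top of that. The helper names mesh_groups and adjacent_stage_pairs and the (pp, dp, tp) layout are illustrative assumptions, not DeepSpeed or PyTorch APIs.

```python
from itertools import product


def mesh_groups(shape, dim):
    """Rank groups along `dim` of a row-major mesh (assumed layout)."""
    # Row-major strides: the last dimension varies fastest.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    # Fix every coordinate except `dim`, then walk along `dim`.
    other = [range(n) if i != dim else [0] for i, n in enumerate(shape)]
    groups = []
    for coord in product(*other):
        base = sum(c * s for c, s in zip(coord, strides))
        groups.append([base + k * strides[dim] for k in range(shape[dim])])
    return groups


def adjacent_stage_pairs(pp_group):
    """P2P pairs between neighbouring pipeline stages (hypothetical helper)."""
    return [(pp_group[i], pp_group[i + 1]) for i in range(len(pp_group) - 1)]


# 8 ranks laid out as (pp=2, dp=2, tp=2); the layout is an assumption.
shape = (2, 2, 2)
pp_groups = mesh_groups(shape, dim=0)
dp_groups = mesh_groups(shape, dim=1)
tp_groups = mesh_groups(shape, dim=2)
p2p_pairs = [pair for g in pp_groups for pair in adjacent_stage_pairs(g)]
```

The per-dimension groups fall out of the mesh directly, while p2p_pairs is the kind of extra structure a subclass would have to build itself.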

This is still an early-stage idea. Any comments or suggestions are welcome.

eternalNight, Nov 07 '25 08:11

@stas00 Following your suggestion, I just created this issue to focus on discussions on the DeviceMesh topic. Please review and feel free to comment.

eternalNight, Nov 07 '25 08:11

Thank you for starting the discussion, Junjie. This is indeed something we need to solve.

I'm not quite sure DeviceMesh is the way to standardize on since:

  1. You have no intuitive control over rank placement - e.g. TP ranks have to be on the same node. I asked the PyTorch devs about it and their answer was that you have to create the mesh so that TP ranks are listed last, so it's not explicit and is easy to mess up.
  2. Moreover, we have situations where groups need to be created before deepspeed.initialize - like UlyssesSP - and DeviceMesh forces you to create all groups at once - it'll fail if you try to init just one group without using them all.
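
A small sketch of the placement issue in point 1, assuming a 2-D row-major mesh and a launcher that packs consecutive ranks onto the same node (4 GPUs per node here). Both helper names are hypothetical. It shows why the TP dimension has to be listed last for TP groups to stay within a node.

```python
def groups_2d(shape, dim):
    """Rank groups along `dim` of a 2-D row-major mesh (assumed layout)."""
    rows, cols = shape
    grid = [[r * cols + c for c in range(cols)] for r in range(rows)]
    if dim == 1:
        return grid  # each row is a group along the last dimension
    return [[grid[r][c] for r in range(rows)] for c in range(cols)]


def nodes_spanned(group, gpus_per_node):
    """Number of distinct nodes a group touches, assuming rank // gpus_per_node."""
    return len({rank // gpus_per_node for rank in group})


GPUS_PER_NODE = 4  # assumption for this example

# TP listed last: mesh shape (dp=2, tp=4), TP groups are contiguous ranks.
tp_last = groups_2d((2, 4), dim=1)
# TP listed first: mesh shape (tp=4, dp=2), TP groups are strided across nodes.
tp_first = groups_2d((4, 2), dim=0)

tp_last_nodes = [nodes_spanned(g, GPUS_PER_NODE) for g in tp_last]
tp_first_nodes = [nodes_spanned(g, GPUS_PER_NODE) for g in tp_first]
```

Nothing in the mesh shape itself says which ordering is intended, which is the "not explicit" part of the concern.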

On the other hand Megatron-LM's latest incarnation has a full-blown 2K lines parallel-state module (mpu) which gives you all the flexibility you need. https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py

I think DeviceMesh is a super-nice abstraction for a lay user who needs a quick setup. But for something as complex as what we have, I'd investigate it more. Perhaps separate components of DeviceMesh could be used, which don't force you to init all ranks at once, but then the neatness is lost. I'm not sure.

I think we need full control and things being very explicit to avoid mistakes that would lead to performance deterioration.

stas00, Nov 07 '25 15:11

@stas00 Thanks for the comments! That helps me understand the problem better.

Indeed, DeviceMesh does not provide enough flexibility for our use cases, and we must either extend it or base ourselves on a different framework.

With regard to the APIs of that framework, here's what I think is required:

  • An interface to create the process groups according to a topology definition. Examples include init_device_mesh in PyTorch and initialize_model_parallel in megatron mpu. DeviceMesh assumes the topology is a grid that applies to the whole model, but that doesn't necessarily hold, as attention and MLP blocks may use different strategies. The approach megatron mpu takes (i.e., one keyword argument for each parallel dimension) looks more flexible for training.
  • A set of interfaces to get world size, rank of current and sibling processes and the process group of a specific dimension, such as get_group and get_coordinate of DeviceMesh and get_xx series in mpu.
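
The two requirements above could be sketched roughly as follows. ParallelTopology and all of its members are hypothetical names invented for illustration (not DeepSpeed, PyTorch or Megatron APIs). It only computes rank membership and creates no real process groups, and the (pp, dp, tp) ordering with TP innermost is an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ParallelTopology:
    """Hypothetical per-model topology; one keyword argument per dimension."""

    world_size: int
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1

    def __post_init__(self):
        model_parallel = self.tensor_parallel_size * self.pipeline_parallel_size
        if self.world_size % model_parallel != 0:
            raise ValueError("world_size must be divisible by tp * pp")

    @property
    def data_parallel_size(self):
        # DP takes whatever is left after TP and PP.
        return self.world_size // (self.tensor_parallel_size * self.pipeline_parallel_size)

    def get_coordinate(self, rank):
        """Coordinate of `rank`, with TP innermost and PP outermost."""
        tp = rank % self.tensor_parallel_size
        dp = (rank // self.tensor_parallel_size) % self.data_parallel_size
        pp = rank // (self.tensor_parallel_size * self.data_parallel_size)
        return {"pp": pp, "dp": dp, "tp": tp}

    def get_group_ranks(self, dim, rank):
        """Ranks that match `rank` in every coordinate except `dim`."""
        mine = self.get_coordinate(rank)
        return [
            r for r in range(self.world_size)
            if all(v == mine[k] for k, v in self.get_coordinate(r).items() if k != dim)
        ]


# Instances rather than module-level state: two models in one world
# can carry different plans.
actor = ParallelTopology(world_size=8, tensor_parallel_size=4)
critic = ParallelTopology(world_size=8, tensor_parallel_size=2, pipeline_parallel_size=2)
```

Because the topology is an ordinary object, consistency checks can live in one place (here, __post_init__) and different models are free to hold different instances.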

In addition to the common APIs above, DeviceMesh and mpu also have unique features, but I'm not sure if we should include them:

  • mpu provides a series of set_xx APIs for setting world sizes or ranks without touching any created groups. I wonder what their primary use cases are.
  • DeviceMesh allows slicing, flattening (and possibly reshaping in the future) an existing mesh. That's useful when someone needs to "reshape" a grid for EP, but may not be needed if the EP groups are explicitly set up (like mpu).

Creating process groups in multiple steps (as UlyssesSP requires today) looks error-prone to me. If someone later wants to explore combining UlyssesSP with other parallelism techniques (say, EP for MoE? not sure if that's ever possible, though), we don't have a place for topology consistency checks. How about something like:

    mpu = initialize_model_parallel(config, ...)
    UlyssesSPAttentionHF.register_with_transformers(
        ...,
        mpu=mpu,
        ...)
    # other initialization here
    model, optimizer, dataloader, lr_scheduler = deepspeed.initialize(model, config=config, mpu=mpu, ...)

As for rank placement, I think that's something we should abstract away from the sight of users. Internally, we can include some heuristics to check if members of the same group are close enough and warn on potential performance degradations. The way mpu designs its interface also helps when the user does not specify the order explicitly.
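
One possible shape for such a heuristic, as a sketch only: check_group_locality is a hypothetical helper, and it assumes that rank // gpus_per_node identifies the node, which depends on how the launcher assigns ranks.

```python
import warnings


def check_group_locality(dim_name, groups, gpus_per_node):
    """Warn if any group along `dim_name` spans more than one node.

    Assumes consecutive ranks are packed onto the same node.
    Returns the offending groups so callers can inspect or override.
    """
    offenders = [
        g for g in groups
        if len({rank // gpus_per_node for rank in g}) > 1
    ]
    if offenders:
        warnings.warn(
            f"{len(offenders)} of {len(groups)} {dim_name} groups span multiple "
            "nodes; communication along this dimension may be slower",
            RuntimeWarning,
        )
    return offenders


# Contiguous TP groups stay on one node each, so nothing is reported.
local = check_group_locality("tp", [[0, 1, 2, 3], [4, 5, 6, 7]], gpus_per_node=4)
```

Returning the offending groups rather than raising keeps the check advisory, so an explicit user-provided placement is never rejected.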

The primary drawback of mpu is that the module is designed in a singleton object style, thus making it unsuitable in some RL use cases.

To sum up, I would like to investigate further on how to leverage megatron’s parallel_state module elegantly. That module has almost all we need, but still requires extension for UlyssesSP and per-model topology. Previously we copied and modified it in DeepSpeed repo, but that requires non-trivial effort when we want to integrate upstream updates.

eternalNight, Nov 12 '25 05:11

mpu provides a series of set_xx APIs for setting world sizes or ranks without touching any created groups. I wonder what their primary use cases are.

Probably to allow partial group inits - as I exemplified UlyssesSP needs its group init very early (but in reality I think it could be postponed, but would be more restrictive to how it can be integrated). So I'd imagine you'd predefine the sets of ranks, but init them as needed?

Creating process groups in multiple steps (as UlyssesSP requires today) looks error-prone to me

The reason is that in many frameworks the DataLoader (DL) is dealt with very early on, and Ulysses has its own DL wrapper that requires the groups early. But of course the wrapper could be applied much later, after deepspeed.initialize - I'm actually trying to hack it into HF Transformers, and the timing in HF Trainer is very different from that in HF Accelerate.

Or, as you're suggesting, the topology should be set up before deepspeed.initialize - I think this is the best way. Perhaps it could have settings to delay some inits until they are used, if there are such use cases?

As for rank placement, I think that's something we should abstract away from the sight of users

Sure, but they should be able to override it. So ideally the API should be explicit first, and then an additional API is added to hide the complexity, without preventing advanced users from doing it the way they want.

The primary drawback of mpu is that the module is designed in a singleton object style, thus making it unsuitable in some RL use cases.

Could you please describe why it breaks in RL use cases?

To sum up, I would like to investigate further on how to leverage megatron’s parallel_state module elegantly. That module has almost all we need, but still requires extension for UlyssesSP and per-model topology. Previously we copied and modified it in DeepSpeed repo, but that requires non-trivial effort when we want to integrate upstream updates.

Thank you for working on that, Junjie! Feel free to design a better way for UlyssesSP while you're at it - it'd just have to be backward compatible, since we already have it integrated in several frameworks, or we would create a new API while keeping the old one working for a year or so.

stas00, Nov 13 '25 07:11

mpu provides a series of set_xx APIs for setting world sizes or ranks without touching any created groups. I wonder what their primary use cases are.

Probably to allow partial group inits - as I exemplified UlyssesSP needs its group init very early (but in reality I think it could be postponed, but would be more restrictive to how it can be integrated). So I'd imagine you'd predefine the sets of ranks, but init them as needed?

Partial group creation only works for the last dimension (such as TP). For intermediate dimensions, the ranks belonging to the same group cannot be determined before the world sizes of the lower dimensions are specified. So it still looks odd to me that megatron provides something like set_data_parallel_rank.
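
A quick illustration of that point, assuming a (pp, dp, tp) layout with TP innermost. tp_group and dp_group are hypothetical helpers: TP membership needs only the TP size, whereas DP membership changes as soon as the TP size changes.

```python
def tp_group(rank, tp_size):
    """TP peers of `rank`; needs only the innermost (TP) size."""
    start = (rank // tp_size) * tp_size
    return list(range(start, start + tp_size))


def dp_group(rank, tp_size, dp_size):
    """DP peers of `rank`; cannot be computed without knowing tp_size."""
    block = tp_size * dp_size                    # ranks per pipeline stage
    start = (rank // block) * block + rank % tp_size
    return [start + i * tp_size for i in range(dp_size)]


# Same rank, same DP size, different TP size: different DP peers.
dp_with_tp2 = dp_group(5, tp_size=2, dp_size=2)
dp_with_tp4 = dp_group(5, tp_size=4, dp_size=2)
```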

Perhaps later when we encounter cases where such set_xx interfaces are needed, we'll have a better idea what they are for and then introduce them into DeepSpeed. But for now, let's focus on initialization and getters.

Creating process groups in multiple steps (as UlyssesSP requires today) looks error-prone to me

The reason is that in many frameworks the DataLoader (DL) is dealt with very early on, and Ulysses has its own DL wrapper that requires the groups early. But of course the wrapper could be applied much later, after deepspeed.initialize - I'm actually trying to hack it into HF Transformers, and the timing in HF Trainer is very different from that in HF Accelerate.

Or, as you're suggesting, the topology should be set up before deepspeed.initialize - I think this is the best way. Perhaps it could have settings to delay some inits until they are used, if there are such use cases?

Probably. There is some complexity because torch.distributed.new_group must be called by all ranks in the world even though some of them may not be part of the new ProcessGroup (ref), but on-demand creation of ProcessGroups for a specific dimension still looks viable.
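
A sketch of what on-demand creation per dimension might look like under that constraint. LazyGroups and create_groups_for_dim are hypothetical names. The group factory is injected so the sketch runs without a distributed backend; in real use it would presumably be something like `lambda ranks: torch.distributed.new_group(ranks=ranks)`, and every rank would have to call get() for a given dimension at the same point in the program.

```python
def create_groups_for_dim(all_group_ranks, my_rank, new_group_fn):
    """Create every group of one dimension; return the one containing my_rank.

    All ranks iterate over the same list in the same order, so each rank
    takes part in every creation call, including for groups it is not in.
    """
    my_group = None
    for ranks in all_group_ranks:
        group = new_group_fn(ranks)
        if my_rank in ranks:
            my_group = group
    return my_group


class LazyGroups:
    """Hypothetical cache that creates a dimension's groups on first use."""

    def __init__(self, layout, my_rank, new_group_fn):
        self._layout = layout            # dim name -> list of rank lists
        self._my_rank = my_rank
        self._new_group_fn = new_group_fn
        self._cache = {}

    def get(self, dim):
        if dim not in self._cache:
            self._cache[dim] = create_groups_for_dim(
                self._layout[dim], self._my_rank, self._new_group_fn
            )
        return self._cache[dim]


# Stand-in factory that records calls instead of creating real groups.
calls = []

def fake_new_group(ranks):
    calls.append(tuple(ranks))
    return ("pg", tuple(ranks))

layout = {"tp": [[0, 1], [2, 3]], "dp": [[0, 2], [1, 3]]}
groups = LazyGroups(layout, my_rank=2, new_group_fn=fake_new_group)
tp = groups.get("tp")   # creates both TP groups, keeps the one with rank 2
```

The DP groups in this example are never created unless something asks for them, which is the delayed-init behaviour discussed above.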

As for rank placement, I think that's something we should abstract away from the sight of users

Sure, but they should be able to override it. So ideally the API should be explicit first, and then an additional API is added to hide the complexity, without preventing advanced users from doing it the way they want.

The primary drawback of mpu is that the module is designed in a singleton object style, thus making it unsuitable in some RL use cases.

Could you please describe why it breaks in RL use cases?

I should have been more accurate. It does not totally break RL post-training, but can be sub-optimal because all models (actor, critic, reference, reward models) are forced to share the parallelism plan. That's what I learned from https://github.com/volcengine/verl/discussions/897, but I don't have richer experience yet.

eternalNight, Nov 14 '25 08:11

Also remember that DeepSpeed creates the default group across all GPUs but usually doesn't use it, since it then creates new groups which it does use. This probably wastes about 0.5GB per GPU: each group uses additional GPU memory, some of which isn't accounted for because the structure lives in NCCL, which doesn't report to torch.cuda memory allocations, but it can be tracked using NVML GPU memory snapshotting. So it'd be good to eliminate it as well if it's not really needed, or to use it and not create another DP group if all GPUs are DP.
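
The last suggestion could look roughly like this. pick_dp_group is a hypothetical helper, and the world group and group factory are injected so the sketch runs without a distributed backend. In real use they would presumably be torch.distributed.group.WORLD and a wrapper around torch.distributed.new_group, with the usual requirement that all ranks take part in group creation.

```python
def pick_dp_group(dp_ranks, world_size, world_group, new_group_fn):
    """Reuse the default group when DP covers every rank in the world.

    Only when the DP group is a strict subset of the world is a new
    group created, avoiding a redundant communicator.
    """
    if sorted(dp_ranks) == list(range(world_size)):
        return world_group
    return new_group_fn(dp_ranks)


# Stand-ins for illustration; no real process groups are created.
WORLD = "world-group"
created = []

def fake_new_group(ranks):
    created.append(tuple(ranks))
    return ("pg", tuple(ranks))

pure_dp = pick_dp_group([0, 1, 2, 3], 4, WORLD, fake_new_group)
mixed = pick_dp_group([0, 2], 4, WORLD, fake_new_group)
```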

stas00, Nov 14 '25 20:11