
Does tp_overlap require the tensor parallel size to equal the world size?

Open kuangdao opened this issue 1 year ago • 5 comments

I want to set the TP size to 2 while it is smaller than the global world size (8 GPUs in my case, so the script builds four TP groups of 2 ranks each).

The code is:


import os
import sys
import subprocess
import argparse

import torch
import torch.distributed as dist

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling


def parse_args(argv=None, namespace=None):
    parser = argparse.ArgumentParser(
        description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers."
    )
    parser.add_argument(
        "-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations."
    )
    parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.")
    parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.")
    parser.add_argument(
        "-n", "--num-heads", type=int, default=64, help="Number of attention heads."
    )
    parser.add_argument(
        "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head."
    )
    parser.add_argument(
        "--mlp-expansion-factor",
        type=int,
        default=4,
        help="MLP block intermediate size as a factor of hidden dimension.",
    )
    parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
    parser.add_argument(
        "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
    )
    parser.add_argument(
        "--no-comm-overlap",
        action="store_true",
        default=False,
        help="Disable the comm+GEMM overlap.",
    )
    parser.add_argument("-v", "--verbose", action="store_true", default=False)
    return parser.parse_args(argv, namespace)


def train(opts):
    WORLD_RANK = int(os.getenv("RANK"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE"))

    def dist_print(msg, end="\n", all_ranks=False):
        if WORLD_RANK == 0 or all_ranks:
            print(f"[RANK-{WORLD_RANK}] {msg}", end=end)


    torch.cuda.set_device(WORLD_RANK)
    torch.manual_seed(opts.seed + WORLD_RANK)
    torch.cuda.manual_seed(opts.seed + WORLD_RANK)

    dist.init_process_group(
        backend="nccl",
        rank=WORLD_RANK,
        world_size=WORLD_SIZE,
        device_id=torch.device(f"cuda:{WORLD_RANK}"),
    )
    
    

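    # Split the 8-rank world into four tensor-parallel groups of 2 ranks each.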
    tp_group_0 = dist.new_group([0, 1], backend="nccl")
    tp_group_1 = dist.new_group([2, 3], backend="nccl")
    tp_group_2 = dist.new_group([4, 5], backend="nccl")
    tp_group_3 = dist.new_group([6, 7], backend="nccl")

    if WORLD_RANK in [0, 1]:
        tp_group = tp_group_0
    elif WORLD_RANK in [2, 3]:
        tp_group = tp_group_1
    elif WORLD_RANK in [4, 5]:
        tp_group = tp_group_2
    elif WORLD_RANK in [6, 7]:
        tp_group = tp_group_3

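    # Sanity check: all-reduce a small tensor within the local TP group only.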
    tensor = torch.ones([2, 2]).cuda() * WORLD_RANK
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=tp_group)

    print("after allreduce is : {}".format(tensor))


    tp_size = dist.get_world_size(tp_group)


    ag_cfg = {  
        "method": "ring_exchange",
        "num_splits": 8,
        "num_sm": 1,
        "set_sm_margin": False,
    }
    rs_cfg = {  
        "method": "ring_exchange",
        "num_splits": 4,
        "num_sm": 1,
        "set_sm_margin": True,
    }
    hidden_size = opts.num_heads * opts.head_dim
    batched_size = opts.seq_length * opts.batch_size

    print("batched_size is : {}".format(batched_size))

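    # Create the Userbuffers workspaces used by comm+GEMM overlap for this TP group.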
    if not opts.no_comm_overlap:
        te.initialize_ub(
            [batched_size, hidden_size],
            tp_group,
            use_fp8=opts.fp8,
            dtype=torch.bfloat16,
            ub_cfgs={
                "fc1_fprop": ag_cfg,
                "fc1_dgrad": rs_cfg,
                "fc2_fprop": rs_cfg,
                "fc2_dgrad": ag_cfg,
            },
        )

    
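    # LayerNormMLP sharded over the 2-GPU TP group, with sequence parallelism enabled.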
    model = te.LayerNormMLP(
        hidden_size,
        opts.mlp_expansion_factor * hidden_size,
        params_dtype=torch.bfloat16,
        device="cuda",
        tp_group=tp_group,
        tp_size=tp_size,
        set_parallel_mode=True,
        sequence_parallel=True,  
        seq_length=opts.seq_length,
        micro_batch_size=opts.batch_size,
        ub_overlap_rs_dgrad=not opts.no_comm_overlap,
        ub_overlap_rs=not opts.no_comm_overlap,
        ub_overlap_ag=not opts.no_comm_overlap,
    )

    optim = torch.optim.Adam(model.parameters(), lr=0.0001)

    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")

    for i in range(opts.num_iters):
        dist_print(f"Iter {i+1}", all_ranks=opts.verbose)

        dist_print("|-- Generate random input batch", all_ranks=opts.verbose)
        x = torch.rand(
            (opts.seq_length // tp_size, opts.batch_size, hidden_size),
            dtype=torch.bfloat16,
            device="cuda",
            requires_grad=True,
        )

        dist_print("|-- Forward pass", all_ranks=opts.verbose)
        with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=tp_group):
            y = model(x)
            dist_print("|-- Compute loss", all_ranks=opts.verbose)
            loss = y.flatten().sum()

        dist_print("|-- Backward pass", all_ranks=opts.verbose)
        loss.backward()

        dist_print("|-- Optimizer step", all_ranks=opts.verbose)
        optim.step()

    te.destroy_ub()
    dist.destroy_process_group()


if __name__ == "__main__":
    if "TORCHELASTIC_RUN_ID" in os.environ.keys():
        args = parse_args()
        train(args)
    else:
        subprocess.run(
            ["torchrun", f"--nproc-per-node={torch.cuda.device_count()}", *sys.argv],
            env=os.environ,
            check=True,
        )
    os._exit(0)


I run it with: torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_sub_group.py

The error is:

[screenshot of the error message]

The commit id of TransformerEngine is 4a4f05dadf7032ff2f4c0780d9adcde77878c7b1.

I am using the Docker image nvcr.io/nvidia/nemo:24.05.

kuangdao • Jun 25 '24

The tensor parallel group can be a subset of the world group. We frequently split the world group into orthogonal tensor-parallel, data-parallel, and pipeline-parallel groups.
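For illustration only, here is a minimal sketch of that splitting with plain torch.distributed; it is not part of TE or of the script above, and the helper name and group layout are my own:

import torch.distributed as dist

def build_tp_dp_groups(tp_size):
    # Assumes dist.init_process_group() has already been called on every rank.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    assert world_size % tp_size == 0, "world size must be divisible by tp_size"

    tp_group = dp_group = None

    # Tensor-parallel groups: consecutive ranks, e.g. [0, 1], [2, 3], ... for tp_size=2.
    # Every rank must call new_group() for every group, even groups it does not join.
    for start in range(0, world_size, tp_size):
        ranks = list(range(start, start + tp_size))
        group = dist.new_group(ranks, backend="nccl")
        if rank in ranks:
            tp_group = group

    # Data-parallel groups: ranks holding the same shard across TP groups,
    # e.g. [0, 2, 4, 6] and [1, 3, 5, 7] for tp_size=2 on 8 GPUs.
    for offset in range(tp_size):
        ranks = list(range(offset, world_size, tp_size))
        group = dist.new_group(ranks, backend="nccl")
        if rank in ranks:
            dp_group = group

    return tp_group, dp_group

Each rank would then pass its tp_group (and dist.get_world_size(tp_group) as tp_size) to the TE modules, which is what the script above does by hand for 8 ranks.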

Based on the error message, it looks like there's an error when NCCL is initializing IPC communicators: https://github.com/NVIDIA/TransformerEngine/blob/4a4f05dadf7032ff2f4c0780d9adcde77878c7b1/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp#L501

To get more information, can you set NCCL_DEBUG=WARN in the environment?
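For completeness, one way to make sure the variable reaches every rank is to set it before the process group is created; the subsystem filter below is just an optional suggestion, not something required by TE:

import os

# NCCL reads NCCL_DEBUG when communicators are created, so set it before
# dist.init_process_group(), or export it in the shell before launching torchrun.
os.environ.setdefault("NCCL_DEBUG", "WARN")
os.environ.setdefault("NCCL_DEBUG_SUBSYS", "INIT,P2P")  # optional, narrows the output

Setting NCCL_DEBUG=INFO instead of WARN produces more verbose initialization logs if WARN prints nothing.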

timmoon10 • Jun 25 '24

I have set export NCCL_DEBUG=WARN and there is no additional message.

[screenshot of the output]

kuangdao • Jun 26 '24

@kuangdao TE in general supports TP size < world size, but the comm+GEMM overlap has some unique restrictions. The underlying device-to-device comms code currently assumes TP size == world size. You may be able to get around this limitation by running with UB_SKIPMC=1, but this leverages CUDA IPC handles instead of CUDA Multicast, so it may not be as performant.
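If you try that workaround, a minimal sketch of wiring it into the script above; my assumption is that the variable only needs to be visible before te.initialize_ub() runs:

import os

# Assumption: UB_SKIPMC is read when the Userbuffers communicator is bootstrapped,
# so setting it before te.initialize_ub() (or exporting it before torchrun) is enough.
os.environ["UB_SKIPMC"] = "1"

Equivalently, launch with UB_SKIPMC=1 torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_sub_group.py.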

As a disclaimer, comm+GEMM overlap is currently an experimental and somewhat fragile feature that is not yet fully supported in TE under all circumstances (and intentionally undocumented). That will change in the near future, as we improve the underlying device-to-device comms code and test it more rigorously on different platforms.

denera • Jul 01 '24

Thanks, I understand. I think comm+GEMM overlap is outstanding work, and I hope more documentation, such as design and implementation details, will be provided.

kuangdao • Jul 02 '24

@kuangdao -- we merged some changes to comm+GEMM overlap in the last month specifically to address multi-node mixed DP/TP use cases. This feature is still restricted to tp_size <= local_size, where local_size is the number of GPUs in a single NVLink domain (currently a single physical node of at most 8 GPUs), but it now functions correctly with model replication across node boundaries. Could you test again and confirm whether this works for your use case?
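As a small illustration of that restriction, here is a sketch of a guard that could sit near the top of train() in the script above; it assumes a torchrun launch, which exports LOCAL_WORLD_SIZE, and the tp_size value is just an example:

import os

# LOCAL_WORLD_SIZE is set by torchrun to the number of ranks on this node,
# which bounds the NVLink domain that comm+GEMM overlap can use.
local_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
tp_size = 2  # example value; must not exceed the per-node GPU count
assert tp_size <= local_size, (
    f"comm+GEMM overlap needs tp_size ({tp_size}) <= GPUs per node ({local_size})"
)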

denera • Aug 16 '24