
Maximize Mistral throughput

Open aldopareja opened this issue 1 year ago • 2 comments

The InstructLab backend currently focuses on Mistral fine-tuning, and I'm trying to maximize throughput for that. If anyone notices anything obvious or has any suggestions, I'd truly appreciate it. @raghukiran1224 mentioned that posting an issue here might help.

I'm currently seeing a throughput of around 90 samples per second at a max context length of 2600 tokens (though the average sample is only around 500 tokens) on 80 GPUs in prod Vela. On a single node I get a throughput of around 11.2 samples per second, and the best configuration there is SHARD_GRAD_OP (ZeRO stage 2) with no gradient checkpointing.

The main bottleneck is the networking, so using the largest possible batch size maximizes throughput: the communication cost is almost the same regardless of the batch size. For that reason I ended up using _HYBRID_SHARD_ZERO2 and enabling gradient checkpointing to fit a batch size of 20 samples per GPU at 2600 max length.
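
For context, a quick back-of-the-envelope check of what those numbers imply per optimizer step (my own arithmetic, assuming a single forward/backward per step and no gradient accumulation; not from the original measurements):

gpus = 80
per_gpu_batch = 20
samples_per_sec = 90

global_batch = gpus * per_gpu_batch            # 1600 samples per step
step_time = global_batch / samples_per_sec     # roughly 17.8 s per step
print(f"global batch = {global_batch}, approx. step time = {step_time:.1f} s")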

These are the main parts to look at:

Model setup

Currently using _HYBRID_SHARD_ZERO2, but I have experimented with all the possibilities. I couldn't get torch.compile to work, and I had to enable gradient checkpointing to maximize the batch size.

import math
from functools import partial

import torch
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer


def setup_model(model_name, tokenizer):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )
    if len(tokenizer) > model.config.vocab_size:
        print(
            f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size"
        )
        # make the vocab size a multiple of 8 for sharding the embedding layer
        model.resize_token_embeddings(int(8 * math.ceil(len(tokenizer) / 8.0)))

    assert model.__class__.__name__ in [
        "MistralForCausalLM"
    ], f"Model class name: {model.__class__.__name__} is not supported."

    model = FSDP(
        model,
        # wrap each MistralDecoderLayer in its own FSDP unit
        auto_wrap_policy=partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                MistralDecoderLayer,
            },
        ),
        # use_orig_params=True,
        limit_all_gathers=True,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        # ZeRO-2 sharding within a node, full replica on every node
        sharding_strategy=ShardingStrategy._HYBRID_SHARD_ZERO2,
        device_id=torch.cuda.current_device(),
    )
    model.gradient_checkpointing_enable()
    # model = torch.compile(model)
    return model
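
For completeness, here is a minimal sketch of how setup_model could be driven under torchrun. The model name, the LOCAL_RANK handling, and the tokenizer choice are my assumptions for illustration, not details from the original post.

import os

import torch
import torch.distributed as dist
from transformers import AutoTokenizer


def main():
    # torchrun sets LOCAL_RANK per process (assuming one process per GPU)
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")

    model_name = "mistralai/Mistral-7B-v0.1"  # hypothetical checkpoint for this sketch
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = setup_model(model_name, tokenizer)
    # ... build train_loader, optimizer, scheduler and run the training loop below ...


if __name__ == "__main__":
    main()

Launched with something like torchrun --nnodes=10 --nproc_per_node=8 train.py.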

training loop

Importantly, note the use_cache handling: even though the explicit use_cache=False below is commented out, the KV cache still ends up disabled, but only because gradient checkpointing is enabled (Hugging Face turns the cache off when checkpointing is active). Without checkpointing it would default to True (see the cache=True runs in the next comment). A short sketch after the loop shows how to make this explicit.

        for batch in train_loader:
            start = time.time()
            # move the batch tensors to this rank's GPU
            for k in batch:
                batch[k] = batch[k].to(local_rank)

            output = model(
                **batch,
                # use_cache=False,  # effectively off here because gradient checkpointing is enabled
            )

            loss = output["loss"]
            loss.backward()

            # gradient accumulation: only step the optimizer every N micro-batches
            if global_step % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
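
If you'd rather not rely on checkpointing to switch the cache off, a minimal alternative (my suggestion, not part of the original code) is to disable it on the config right after loading the model:

# sketch: add right after AutoModelForCausalLM.from_pretrained(...) in setup_model
model.config.use_cache = False  # skip KV-cache allocation during training forwards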

aldopareja · Mar 25 '24 02:03

Throughput test

I performed many experiments in the prod cluster trying to get the best throughput using PyTorch's FSDP. In general, once the number of nodes rises above a certain threshold, the best throughput comes from the HYBRID_SHARD approach, which does FULL_SHARD within each node while every node keeps a full replica of the model, so most of the all_gather operations happen inside a node, over much faster links than the inter-node network.
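
As a quick legend for the labels in the results below (my summary of the FSDP enums, not text from the original experiments): PRE and POST refer to BackwardPrefetch.BACKWARD_PRE and BackwardPrefetch.BACKWARD_POST, and the sharding strategies map roughly onto ZeRO stages as follows.

from torch.distributed.fsdp import ShardingStrategy

STRATEGIES = {
    "FULL_SHARD": ShardingStrategy.FULL_SHARD,        # ZeRO-3: shard params, grads, optimizer state
    "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,  # ZeRO-2: shard grads and optimizer state only
    "HYBRID_SHARD": ShardingStrategy.HYBRID_SHARD,    # FULL_SHARD within a node, replicate across nodes
    "HYBRID_ZERO2": ShardingStrategy._HYBRID_SHARD_ZERO2,  # SHARD_GRAD_OP within a node, replicate across nodes
    "NO_SHARD": ShardingStrategy.NO_SHARD,            # plain DDP-style replication
}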

Also, it's not advisable to just increase the batch size until you run out of memory and then lower it a bit: the CUDA malloc retries count goes up and creates bottlenecks. It's important to find the largest batch size that keeps CUDA malloc retries at 1 or 2 at most (you can see the counter in torch.cuda.memory_summary()), even if that batch size is smaller; throughput is highest there.
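
A minimal sketch (mine, not from these experiments) for watching that counter programmatically during a run; torch.cuda.memory_stats() exposes the same number that torch.cuda.memory_summary() prints as cudaMalloc retries.

import torch

def log_malloc_retries(step):
    # "num_alloc_retries" counts how often the caching allocator had to free
    # cached blocks and retry cudaMalloc; keep it at ~1-2 for the whole run.
    retries = torch.cuda.memory_stats().get("num_alloc_retries", 0)
    print(f"step {step}: CUDA malloc retries so far = {retries}")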

These are the results of the experiments I did to test:

LEN 4600

FULL_SHARD and PRE and GRAD_CKPT:
  • SUMMARY: batch_size=10, time=73.66316413879395, throughput=6.516146032576699
  • SUMMARY: batch_size=12, time=70.2414870262146, throughput=6.833564184291811
  • SUMMARY: batch_size=14, time=91.22835445404053, throughput=4.910753645871884
SHARD_OP and PRE and GRAD_CKPT:
  • SUMMARY: batch_size=8, time=70.51107478141785, throughput=6.353610642851947
  • SUMMARY: batch_size=10, time=73.39471912384033, throughput=6.539978366153278
  • SUMMARY: batch_size=12, time=67.05956053733826, throughput=5.726252529529654
SHARD_GRAD_OP and PRE:
  • SUMMARY: batch_size=1, time=60.53613209724426, throughput=6.211164907227905
FULL_SHARD and PRE
  • SUMMARY: batch_size=1, time=60.23241829872131, throughput=5.976845984172338
FULL_SHARD and POST -- 3 NODES
  • SUMMARY: batch_size=2, time=62.37387943267822, throughput=9.234633320318665
HYBRID_ZERO_2 and POST and GRAD_CKPT -- 3 NODES
  • batch size 20 OOM
  • SUMMARY: batch_size=12, time=74.82603144645691, throughput=19.244634527756094
HYBRID_SHARD and POST and GRAD_CKPT -- 3 NODES
  • SUMMARY: batch_size=12, time=74.99147772789001, throughput=19.20217749301761
  • SUMMARY: batch_size=14, time=67.83051562309265, throughput=19.814084945545247
  • bs 16 OOM
SHARD_GRAD_OP and POST and GRAD_CKPT -- 3 NODES
  • SUMMARY: batch_size=12, time=74.90133571624756, throughput=19.22528857416179
  • SUMMARY: batch_size=14, time=64.9182357788086, throughput=20.70296024465901, MALLOC RETRIES: 2
FULL_SHARD and POST and GRAD_CKPT -- 3 NODES
  • SUMMARY: batch_size=14, time=65.47520995140076, throughput=20.526851280415173
CPU_OFFLOAD and POST and GRAD_CKPT -- 3 NODES
  • SUMMARY: batch_size=16, time=78.79792857170105, throughput=19.492884103647853, MALLOC: 1

len 2048

FULL_SHARD, POST, GRAD_CKPT -- 3 NODES
  • SUMMARY: batch_size=16, time=70.0143084526062, throughput=38.39213775336893
  • SUMMARY: batch_size=28, time=75.14857411384583, throughput=44.71141527538627
  • SUMMARY: batch_size=30, time=60.386337995529175, throughput=47.69288964157913, MALLOC RETRIES: 1
  • SUMMARY: batch_size=32, time=65.69702291488647, throughput=46.76009462050683

len 2600

SHARD_GRAD_OP, POST, GRAD_CKPT --5 NODES
  • SUMMARY: batch_size=6, time=64.64012384414673, throughput=44.554355293489145
  • SUMMARY: batch_size=14, time=60.21819519996643, throughput=55.797070493155175
  • SUMMARY: batch_size=16, time=67.74451446533203, throughput=56.683551715964114
  • SUMMARY: batch_size=18, time=60.19274282455444, throughput=59.80786914442937
  • SUMMARY: batch_size=20, time=67.63811206817627, throughput=59.13825160843786
FULL_SHARD, POST, GRAD_CKPT -- 5 NODES
  • SUMMARY: batch_size=18, time=65.28372812271118, throughput=55.14389688271886
  • SUMMARY: batch_size=20, time=71.572922706604, throughput=55.88704530451299
  • SUMMARY: batch_size=22, time=77.91983580589294, throughput=56.468280244586595
HYBRID (ZERO 3), PRE, GRAD_CKPT -- 5 NODES
  • SUMMARY: batch_size=12, time=66.17386412620544, throughput=50.7753128916656
  • SUMMARY: batch_size=20, time=69.58629775047302, throughput=57.48257119667158
  • SUMMARY: batch_size=22, time=76.5053014755249, throughput=57.51233076487579
HYBRID (ZERO 3), POST, GRAD_CKPT -- 5 NODES
  • SUMMARY: batch_size=22, time=76.11831855773926, throughput=57.80473041246886
  • SUMMARY: batch_size=24, time=62.483471632003784, throughput=61.45617894676771

len 2600 cache=True

SHARD_GRAD_OP, POST
  • batch_size=2, throughput=11.100
SHARD_GRAD_OP, PRE
  • batch_size=2, throughput=11.45
FULL_SHARD, PRE
  • batch_size=2, throughput=

aldopareja · Mar 25 '24 02:03

Hi, I've been working with @sahilsuneja1 on getting more throughput out of Mixtral for speculator training. I believe it'll also be useful for InstructLab testing, given the code you posted. Feel free to open a chat with both of us if this is still relevant.

ani300 · May 06 '24 14:05