petals

Inference issues on Volta-based swarm

Open: barsuna opened this issue 3 years ago (6 comments)

Hello folks,

I'm trying to run a private swarm on 7 Volta-generation GPUs. As suggested by the docs, I've set torch_dtype to float16 and NUM_BLOCKS to 10 (these are 32 GB GPUs) and removed the load-8-bit argument. All 7 are running on the same Linux host.

The swarm starts OK and loads all the model blocks, but many of the included tests fail.

Even the most basic generation always seems to produce the same token (0 == UNK).

Are older GPUs even supported? There are some notes in the documentation on what to set for pre-Turing GPUs, but the arXiv paper says the server needs a Turing or later generation GPU.

If older GPUs are supported, do I also need to specify torch_dtype=float16 when instantiating the model? (In that case I get RuntimeError: "LayerNormKernelImpl" not implemented for 'Half' when running .generate().)

This is torch 1.12.1+cu113 on CUDA 11.3.

These are my test results:

tests/test_aux_functions.py::test_throughput_basic FAILED <-- this system is behind a proxy, so I think this is expected
tests/test_block_exact_match.py::test_remote_block_exact_match FAILED
tests/test_chained_calls.py::test_forward_backward_exact_match FAILED
tests/test_chained_calls.py::test_chained_inference_exact_match FAILED
tests/test_full_model.py::test_full_model_exact_match[True] FAILED
tests/test_full_model.py::test_full_model_exact_match[False] FAILED
tests/test_full_model.py::test_greedy_generation PASSED
tests/test_full_model.py::test_sampling[sampling_options0] SKIPPED (Sampling is currently not consistent with outputs from Transformers)
tests/test_full_model.py::test_sampling[sampling_options1] SKIPPED (Sampling is currently not consistent with outputs from Transformers)
tests/test_full_model.py::test_sampling[sampling_options2] SKIPPED (Sampling is currently not consistent with outputs from Transformers)
tests/test_full_model.py::test_sampling[sampling_options3] SKIPPED (Sampling is currently not consistent with outputs from Transformers)
tests/test_full_model.py::test_beam_search_generation FAILED
tests/test_linear8bitlt.py::test_layout_exact_match SKIPPED (this test requires a turing-generation or newer GPU, see bitsandbytes docs)
tests/test_linear8bitlt.py::test_linear_exact_match SKIPPED (this test requires a turing-generation or newer GPU, see bitsandbytes docs)
tests/test_linear8bitlt.py::test_linear_no_igemmlt PASSED
tests/test_priority_pool.py::test_priority_pools PASSED
tests/test_remote_sequential.py::test_remote_sequential FAILED
tests/test_remote_sequential.py::test_remote_sequential_prompts FAILED

The greedy generation test passes, but I'm suspicious... could there be an issue with the test itself?

Here is what I see from .generate():

model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals", initial_peers=INITIAL_PEERS)
inputs = tokenizer("Cat sat on", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))
Cat sat on
type(outputs), outputs.shape, outputs
(torch.Tensor, torch.Size([1, 8]), tensor([[40171, 13770, 664, 0, 0, 0, 0, 0]]))

If needed, I can provide the test logs.

barsuna, Dec 28 '22 20:12

Hi! Sorry, we're a bit overwhelmed right now; we'll respond within 48 hours.

justheuristic, Dec 30 '22 15:12

TL;DR: Older GPUs are supported, but most files in ./tests assume that you are running on CPU.

Why do the tests run on CPU? These tests are written to ensure that the code is correct. They run in an isolated environment provided by GitHub Actions, and that environment does not have a GPU.

On GPU you will get minor differences in the output logits, the same as with all other transformers models. This is fine for most applications, unless you need bit-exact equality.
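A minimal sketch of what "minor differences" means for comparisons: instead of checking logits for exact equality, compare them element-wise with a tolerance. The helper name logits_match, its default tolerances and the sample values are illustrative assumptions, not taken from the petals test suite; for tensors, torch.allclose(a, b, rtol=..., atol=...) plays the same role.

```python
import math
from typing import Sequence


def logits_match(
    reference: Sequence[float],
    candidate: Sequence[float],
    rel_tol: float = 1e-3,
    abs_tol: float = 1e-3,
) -> bool:
    """Return True if two logit vectors agree within the given tolerances.

    Hypothetical helper: an exact (==) comparison can fail across devices
    because floating-point sums may be accumulated in a different order.
    """
    if len(reference) != len(candidate):
        return False
    return all(
        math.isclose(r, c, rel_tol=rel_tol, abs_tol=abs_tol)
        for r, c in zip(reference, candidate)
    )


cpu_logits = [2.5, -1.25, 0.0625]
gpu_logits = [2.5009, -1.2497, 0.0627]  # illustrative values, slightly off

print(cpu_logits == gpu_logits)                          # False: not bit-exact
print(logits_match(cpu_logits, gpu_logits))              # True: close enough
print(logits_match(cpu_logits, [float("nan")] * 3))      # False: NaN never matches
```

Note that a tolerance check still catches real breakage: NaN outputs never compare as close to anything.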

Should I specify torch_dtype=float16? We recommend that you do NOT specify it. The original BLOOM model was trained in bfloat16. Switching it to float16 may turn out fine, but there may be some side effects; namely, some training workloads may fail in float16.
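To see why a bfloat16-trained model can misbehave in float16, it helps to compare the largest finite value each format can hold. Both are 16 bits wide, but float16 uses 5 exponent bits and 10 mantissa bits, while bfloat16 uses 8 and 7. The sketch below derives both limits from the standard IEEE 754 layout formulas; the function name max_finite is our own.

```python
def max_finite(exponent_bits: int, mantissa_bits: int) -> float:
    """Largest finite value of an IEEE-754-style binary floating-point format.

    The all-ones exponent is reserved for inf/NaN, so the largest usable
    unbiased exponent equals the bias, 2**(exponent_bits - 1) - 1.
    """
    bias = 2 ** (exponent_bits - 1) - 1
    largest_significand = 2 - 2 ** -mantissa_bits  # binary 1.111...1
    return largest_significand * 2.0 ** bias


FLOAT16_MAX = max_finite(exponent_bits=5, mantissa_bits=10)
BFLOAT16_MAX = max_finite(exponent_bits=8, mantissa_bits=7)

print(FLOAT16_MAX)            # 65504.0
print(f"{BFLOAT16_MAX:.3e}")  # 3.390e+38
```

bfloat16 keeps the same 8-bit exponent as float32, so values that were harmless during bfloat16 training can still be far outside float16's range.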

Why do you get "not implemented for Half" errors? The code that raises the error is in client-side modules (the layernorm after the embeddings), and the error only shows up in float16 on CPU. If you want to fix this error while staying in float16, run the client on GPU. You may also cast to float32 or bfloat16 on CPU, but CPU-based float32 takes up more memory, while CPU-based bfloat16 is fairly slow.
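A rough back-of-the-envelope estimate of the CPU memory trade-off, assuming the client-side weights are dominated by BLOOM-176B's word-embedding table of 250880 x 14336 entries (vocabulary size x hidden size, taken from the public BLOOM config). float32 needs 4 bytes per element; float16 and bfloat16 need 2.

```python
# Bytes per element for the dtypes discussed above.
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}

# Assumed BLOOM-176B word-embedding shape: vocabulary size x hidden size.
VOCAB_SIZE = 250_880
HIDDEN_SIZE = 14_336


def embedding_bytes(dtype: str) -> int:
    """Approximate memory taken by the client-side embedding table in `dtype`."""
    return VOCAB_SIZE * HIDDEN_SIZE * BYTES_PER_ELEMENT[dtype]


for dtype in ("float32", "bfloat16", "float16"):
    print(f"{dtype:>8}: {embedding_bytes(dtype) / 2**30:.1f} GiB")
```

Under these assumptions the table takes about 13.4 GiB in float32 versus about 6.7 GiB in either 16-bit format, which is why float32 on CPU is the memory-hungry option.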

justheuristic, Dec 30 '22 15:12

Thank you for your reply @justheuristic, this helps! I'm not concerned about the tests; I just thought I'd use them to see what's wrong with my setup. Knowing that this setup is supported in general, I've done some more debugging to narrow down the inference issue. This is what I found:

When calling generate, as in

outputs = model.generate(inputs, max_new_tokens=5)

Here is what I see: inside .generate, line 180 of remote_generation.py

hidden_state = session.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]

produces a hidden_state that is all NaNs (I've added some debug output to print its shape and contents):

DEBUG(.generate) hidden_state1 shape torch.Size([1, 14336]) hidden_state: tensor([[nan, nan, nan, ..., nan, nan, nan]], dtype=torch.bfloat16)

Following the call into .step, I see that the first 6 of the 7 servers return meaningful-looking tensors, but the last one (hosting blocks 60-70) returns NaNs

(I'm referring to line 297 of inference_session.py):

outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)

That is, I see:

DEBUG(.step0) server_idx 6 span RemoteSpanInfo(start=60, end=70, peer_id=<libp2p.peer.id.ID (QmfHwyZ2sbq6b7JLXb7mTS4fJmF1SpRMpArZE8E2MSsSce)>)
inputs tensor([[[ -0.9604, 2.4609, -2.0430, ..., 1.3594, 3.2793, -3.2305],
         [ 7.2148, 8.3984, 7.4336, ..., 6.2188, 5.9531, 3.7188],
         [ -3.0078, 5.2852, 2.6855, ..., 0.3782, -12.8125, -7.4609]]], dtype=torch.float16)
outputs tensor([[[nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan],
         [nan, nan, nan, ..., nan, nan, nan]]], dtype=torch.float16)
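The per-server inspection described above amounts to running each span in order and stopping at the first one whose output contains NaNs. A minimal stdlib sketch of that idea follows; Span, first_nan_span and the toy forward functions are hypothetical stand-ins, not petals APIs (with real tensors, torch.isnan(outputs).any() would do the NaN check).

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Optional, Sequence

Vector = List[float]


@dataclass
class Span:
    """Hypothetical stand-in for one server's block range [start, end)."""
    start: int
    end: int
    forward: Callable[[Vector], Vector]


def has_nan(values: Sequence[float]) -> bool:
    return any(math.isnan(v) for v in values)


def first_nan_span(spans: Sequence[Span], inputs: Vector) -> Optional[Span]:
    """Feed activations through the spans in order and return the first span
    whose output contains NaN, or None if every output is NaN-free."""
    if has_nan(inputs):
        raise ValueError("inputs already contain NaN")
    hidden = inputs
    for span in spans:
        hidden = span.forward(hidden)
        if has_nan(hidden):
            return span
    return None


def healthy(xs: Vector) -> Vector:
    return [x * 0.5 for x in xs]


def broken(xs: Vector) -> Vector:
    return [float("nan")] * len(xs)


# Toy swarm: six healthy spans of 10 blocks each, then one that emits NaNs.
spans = [Span(i * 10, (i + 1) * 10, healthy) for i in range(6)]
spans.append(Span(60, 70, broken))

culprit = first_nan_span(spans, [-0.9604, 2.4609, -2.0430])
print(culprit.start, culprit.end)  # 60 70
```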

I have restarted the 7th server, and it loads everything without any errors, but the problem persists.

I have also run with the logger set to debug, but I see nothing particularly suspicious. Any pointers on how to debug this?

barsuna, Jan 02 '23 12:01

I commented out --torch_dtype float16 --compression $COMPRESSION when starting the swarm and workers... and the issue is gone. I'm not sure whether it was the complete restart of everything, the change of dtype back to the default, or the compression setting.

I'll need to do some more tests to see if I can reproduce the issue, but at the very least there is a workaround.

barsuna, Jan 02 '23 14:01

I'm glad that you fixed it :)

It's true that we can't be sure yet, but running a float16 server for a bfloat16 model can indeed result in NaNs. The problem is that float16 breaks down if any intermediate value exceeds ±6.5e4 (the largest finite float16 value is 65504), which may well happen in BLOOM-176B attention. To reiterate: this is just a guess.
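To illustrate the mechanism behind this guess (not a reproduction of what happened on this swarm): the sketch below emulates float16 storage with Python's struct module, whose format code 'e' is IEEE 754 half precision, and maps out-of-range values to ±inf as a float16 cast typically does. A single overflowing attention score is enough: the usual max-subtraction in softmax then computes inf - inf, which is NaN, and the NaN spreads to every probability. The scores are made-up numbers.

```python
import math
import struct
from typing import List


def to_float16(x: float) -> float:
    """Round x to the nearest float16 value; overflow becomes +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:  # struct refuses values that round past 65504
        return math.copysign(math.inf, x)


def softmax_float16(scores: List[float]) -> List[float]:
    """Max-subtracted softmax over scores that were stored in float16."""
    stored = [to_float16(s) for s in scores]
    peak = max(stored)
    exps = [math.exp(s - peak) for s in stored]  # inf - inf gives nan
    total = sum(exps)
    return [e / total for e in exps]


print(softmax_float16([1.0, 2.0, 3.0]))      # finite probabilities
print(softmax_float16([1.0, 2.0, 70000.0]))  # [nan, nan, nan]
```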

Speaking of which, what compression were you using?

justheuristic, Jan 02 '23 19:01

Thanks @justheuristic!

Regarding compression, it was set to 'NONE', since this is all on one server.

I'll maybe try adding a check for inputs exceeding ±6.5e4.
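One way such a check could look, as a minimal sketch: scan the activations entering a block and report the first value that is non-finite or whose magnitude exceeds float16's largest finite value, 65504 (in PyTorch, torch.finfo(torch.float16).max). The function name find_fp16_overflow and its margin parameter are hypothetical; a real check would run on tensors, for example via tensor.abs().max(), rather than on Python lists.

```python
import math
from typing import Iterable, Optional, Tuple

FLOAT16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def find_fp16_overflow(
    values: Iterable[float], margin: float = 1.0
) -> Optional[Tuple[int, float]]:
    """Return (index, value) of the first entry float16 cannot hold safely.

    A margin below 1.0 also flags values that are merely close to the limit,
    e.g. margin=0.5 flags magnitudes above 32752. Returns None if all are safe.
    """
    limit = FLOAT16_MAX * margin
    for i, v in enumerate(values):
        if not math.isfinite(v) or abs(v) > limit:
            return i, v
    return None


print(find_fp16_overflow([-0.96, 2.46, -12.81]))     # None
print(find_fp16_overflow([3.0, -7.2e4, 1.0]))        # (1, -72000.0)
print(find_fp16_overflow([4.0e4, 1.0], margin=0.5))  # (0, 40000.0)
```

Using a margin below 1.0 gives an early warning before values actually overflow, at the cost of some false alarms.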

Performance-wise, I'm getting anywhere from just under 3 seconds per token to just under 9 seconds per token (over more than 1000 tokens), which includes several restarts... which I might open a separate issue for.

Overall this looks pretty promising; I can't wait to try the tuning :)

barsuna, Jan 02 '23 20:01