
[BUG] InferenceEngine produces different outputs from the original torch model even with the same input

Open xshaun opened this issue 2 years ago • 1 comment

Describe the bug: The inference engine returned by deepspeed.init_inference produces different outputs from the original torch model, even when both are given the same input tensor.

To Reproduce

import torch
import deepspeed
from transformers import AutoModelForCausalLM


input_ids = torch.randint(0, 2048, (1, 1024), dtype=torch.int64).cuda()
opt125m = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", torch_dtype=torch.float16
).cuda()

# Reference output from the plain Hugging Face model (no gradients needed)
with torch.no_grad():
    _hg_output = opt125m(input_ids)
print("model size:", sum(p.numel() * p.element_size() for p in opt125m.parameters()))
# model size: 250475520


# Wrap the same model with the DeepSpeed inference engine
model = deepspeed.init_inference(
    model=opt125m,
    tensor_parallel={"tp_size": 1},
    dtype=torch.float16,
    zero={"stage": 0},
    quant={"enabled": False},
    replace_with_kernel_inject=False,
    enable_cuda_graph=False,
).cuda()

print("model size:", sum(p.numel() * p.element_size() for p in model.parameters()))
# model size: 250475520
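# Call the engine returned by init_inference rather than the original opt125m handle
with torch.no_grad():
    _ds_output = model(input_ids)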

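for (_hg_k, _hg_o), (_ds_k, _ds_o) in zip(_hg_output.items(), _ds_output.items()):
    # past_key_values is a nested tuple, not a tensor; compare tensor fields only
    if not torch.is_tensor(_hg_o):
        continue
    # rtol=0, atol=0 turns allclose into an exact-equality check
    assert torch.allclose(_hg_o, _ds_o, rtol=0, atol=0), f"{_hg_k}: {_hg_o}, {_ds_k}: {_ds_o}"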
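The final check passes rtol=0 and atol=0 to torch.allclose, which demands bit-identical outputs. As a minimal illustration, here is a pure-Python sketch of the elementwise criterion documented for torch.allclose, |a - b| <= atol + rtol * |b| (NaN handling omitted). The `allclose` helper and the sample logit lists are hypothetical stand-ins for the real tensors.

```python
def allclose(a, b, rtol=1e-05, atol=1e-08):
    """Pure-Python sketch of the torch.allclose criterion:
    every pair must satisfy |a - b| <= atol + rtol * |b| (no NaN handling)."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


# Hypothetical logits; the middle values are one fp16 ulp (2**-9) apart.
hg_logits = [1.0, 2.5, -0.75]
ds_logits = [1.0, 2.501953125, -0.75]

print(allclose(hg_logits, ds_logits, rtol=0, atol=0))        # False: exact match required
print(allclose(hg_logits, ds_logits, rtol=1e-3, atol=1e-3))  # True: within tolerance
```

With exact tolerances, a single last-bit rounding difference in fp16 is enough to fail the assertion, so the choice of rtol/atol decides whether the report is about bit-exactness or about numerically meaningful drift.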

Expected behavior: DeepSpeed inference should not change the model's parameter values or outputs unless a setting that explicitly does so (such as quantization) is enabled.

In other words, when no optimization that trades off accuracy is enabled, DeepSpeed should always produce outputs identical to those of the original model, preserving its behavior.
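The two "model size" lines printed by the reproduction match, but that figure is the total parameter storage in bytes (numel × element_size), so it only shows that the parameter count and dtype are unchanged, not the values. The arithmetic and the limitation can be sketched in plain Python; the `pack_fp16` helper and the sample weights are hypothetical, with struct's "e" (half-precision) format standing in for fp16 tensors. Comparing values per parameter (for example with torch.equal) is what would actually detect a change.

```python
import struct

# Value printed by the reproduction script; torch.float16 has element_size() == 2.
MODEL_SIZE_BYTES = 250_475_520
FP16_BYTES = 2

assert MODEL_SIZE_BYTES % FP16_BYTES == 0  # byte count splits evenly into 2-byte elements
num_params = MODEL_SIZE_BYTES // FP16_BYTES
print(f"{num_params:,} parameters")  # 125,237,760 parameters


def pack_fp16(values):
    """Serialize floats as little-endian IEEE-754 half precision (2 bytes each)."""
    return struct.pack(f"<{len(values)}e", *values)


# Hypothetical weights before and after a change in the last fp16 bit.
before = pack_fp16([0.5, 1.0, 2.5])
after = pack_fp16([0.5, 1.0, 2.501953125])
print(len(before) == len(after))  # True: same "model size"
print(before == after)            # False: the stored values differ
```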

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 but detected 2.0
 [WARNING]  using untested triton version (2.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ...............  /python3.8/site-packages/torch']
torch version .................... 2.0.0+cu117
deepspeed install path ...........  /python3.8/site-packages/deepspeed']
deepspeed info ................... 0.8.1, unknown, unknown
torch cuda version ............... 11.7
torch hip version ................ None
nvcc version ..................... 11.6
deepspeed wheel compiled w. ...... torch 1.12, cuda 11.6

Screenshots: [Screen Shot 2023-04-18 at 16 03 29]

System info:

  • OS: Ubuntu 20.04
  • GPU: A100, GTX

xshaun, Apr 18 '23 08:04

Hi @xshaun, the difference in logits occurs because DeepSpeed replaces some modules in the model to support tensor parallelism. This replacement isn't necessary when running on only 1 GPU, so I've created a fix for that case: https://github.com/microsoft/DeepSpeed/pull/3449

molly-smith, May 04 '23 23:05
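Why would swapping in a different module change fp16 logits when the weights are identical? One plausible mechanism, offered here as an assumption rather than a description of DeepSpeed's actual kernels, is a different rounding sequence: a fused linear layer can round the matmul plus bias to fp16 once, while a replacement that runs the matmul and the bias add as separate fp16 operations rounds twice. The hypothetical helpers `linear_fused` and `linear_unfused` below simulate this in pure Python, using struct's "e" format (round-half-to-even) as the fp16 rounding step.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE-754 half-precision value
    (round-half-to-even), via struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def linear_fused(xs, ws, b):
    # Dot product and bias accumulated in double precision, rounded to fp16 once.
    return to_fp16(sum(x * w for x, w in zip(xs, ws)) + b)


def linear_unfused(xs, ws, b):
    # Matmul result rounded to fp16 first, then the bias is added and rounded again.
    mm = to_fp16(sum(x * w for x, w in zip(xs, ws)))
    return to_fp16(mm + b)


# Hypothetical activations, weights and bias chosen so the matmul lands on an fp16 tie.
xs, ws, b = [512.0, 0.5], [2.0, 1.0], 0.25
print(linear_fused(xs, ws, b))    # 1025.0
print(linear_unfused(xs, ws, b))  # 1024.0
```

Both paths compute the same mathematical function, yet they disagree by one fp16 ulp, which is the kind of difference an exact allclose(..., rtol=0, atol=0) check will flag.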

The fix PR has been merged, so I'm closing this issue. Please reopen it if the problem persists.

molly-smith, May 11 '23 19:05