V100 w8 llama model Inference failed.

Open YihengBrianWu opened this issue 1 year ago • 4 comments

System Info

  • GPU: Tesla-V100-SXM2-32GB
  • TRT-LLM Version: v0.9.0
  • CUDA Version: 12.2
  • Driver Version: 470.129.06

Who can help?

@byshiue

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

  1. Convert checkpoint (success)
  convert_command_args = [
      'python3',
      '/opt/tiger/TensorRT-LLM/examples/llama/convert_checkpoint.py',
      f'--model_dir={model_dir}',
      f'--output_dir={ckpt_output_path}',
      f'--tp_size=1',
      '--use_weight_only'
  ]
  2. Build TRT engine (success)
  build_command_args = [
      'trtllm-build',
      f'--checkpoint_dir={ckpt_output_path}',
      f'--output_dir={output_path}',
      '--remove_input_padding=enable',
      '--gpt_attention_plugin=float16',
      '--gemm_plugin=float16',
      '--context_fmha=enable',
      '--paged_kv_cache=enable',
      '--use_custom_all_reduce=disable',
  ]
  3. Engine inference (fails)
  python3 run.py --engine_dir '{output_path}' --max_output_len 100 --tokenizer_dir '{output_path}' --input_text "How do I count to nine in French?"
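The three steps above can be collected into one small helper, which makes it easier to rerun the exact same pipeline on a different GPU. This is only a sketch: the function name `build_commands` and its parameters are hypothetical, the flags are copied from the issue as written, and nothing here is checked against a particular TensorRT-LLM release.

```python
import shlex


def build_commands(model_dir, ckpt_output_path, output_path,
                   convert_script="/opt/tiger/TensorRT-LLM/examples/llama/convert_checkpoint.py",
                   prompt="How do I count to nine in French?"):
    """Return the convert, build, and run argv lists described in the issue.

    The helper only assembles arguments; it does not execute anything.
    """
    convert = [
        "python3", convert_script,
        f"--model_dir={model_dir}",
        f"--output_dir={ckpt_output_path}",
        "--tp_size=1",
        "--use_weight_only",  # the weight-only ("w8") path that fails at inference
    ]
    build = [
        "trtllm-build",
        f"--checkpoint_dir={ckpt_output_path}",
        f"--output_dir={output_path}",
        "--remove_input_padding=enable",
        "--gpt_attention_plugin=float16",
        "--gemm_plugin=float16",
        "--context_fmha=enable",
        "--paged_kv_cache=enable",
        "--use_custom_all_reduce=disable",
    ]
    run = [
        "python3", "run.py",
        "--engine_dir", output_path,
        "--max_output_len", "100",
        # The issue passes the engine directory as the tokenizer directory too.
        "--tokenizer_dir", output_path,
        "--input_text", prompt,
    ]
    return [convert, build, run]


if __name__ == "__main__":
    # Placeholder paths for illustration only.
    for argv in build_commands("/path/to/model", "/path/to/ckpt", "/path/to/engine"):
        print(shlex.join(argv))
```

Each list can be passed to `subprocess.run(argv, check=True)` so that a failing step stops the pipeline at the point where it breaks.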

Expected behavior

The engine returns generated text for the prompt.

actual behavior

libibverbs: Warning: couldn't open config directory '/etc/libibverbs.d'.
[TensorRT-LLM] TensorRT-LLM version: 0.9.0
[TensorRT-LLM][INFO] Engine version 0.9.0 found in the config file, assuming engine(s) built by new builder API.
[TensorRT-LLM][WARNING] [json.exception.type_error.302] type must be string, but is null
[TensorRT-LLM][WARNING] Optional value for parameter kv_cache_quant_algo will not be set.
[TensorRT-LLM][WARNING] [json.exception.out_of_range.403] key 'num_medusa_heads' not found
[TensorRT-LLM][WARNING] Optional value for parameter num_medusa_heads will not be set.
[TensorRT-LLM][WARNING] [json.exception.out_of_range.403] key 'max_draft_len' not found
[TensorRT-LLM][WARNING] Optional value for parameter max_draft_len will not be set.
[TensorRT-LLM][INFO] MPI size: 1, rank: 0
[TensorRT-LLM][INFO] Loaded engine size: 6286 MiB
[TensorRT-LLM][INFO] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +1, GPU +8, now: CPU 6448, GPU 6613 (MiB)
[TensorRT-LLM][INFO] [MemUsageChange] Init cuDNN: CPU +1, GPU +10, now: CPU 6449, GPU 6623 (MiB)
[TensorRT-LLM][WARNING] TensorRT was linked against cuDNN 8.9.6 but loaded cuDNN 8.9.2
[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in engine deserialization: CPU +0, GPU +6282, now: CPU 0, GPU 6282 (MiB)
[TensorRT-LLM][INFO] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 6466, GPU 7223 (MiB)
[TensorRT-LLM][INFO] [MemUsageChange] Init cuDNN: CPU +0, GPU +8, now: CPU 6466, GPU 7231 (MiB)
[TensorRT-LLM][WARNING] TensorRT was linked against cuDNN 8.9.6 but loaded cuDNN 8.9.2
[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6282 (MiB)
[TensorRT-LLM][INFO] Max tokens in paged KV cache: 364288. Allocating 23873978368 bytes.
[TensorRT-LLM][INFO] Max KV cache pages per sequence: 1
terminate called after throwing an instance of 'std::runtime_error'
what(): [TensorRT-LLm Error][fpA_intB Runner] Failed to run cutlass fpA_intB gemm. Error: Error Internal
[n147-181-220:77056] *** Process received signal ***
[n147-181-220:77056] Signal: Aborted (6)
[n147-181-220:77056] Signal code: (-6)
[n147-181-220:77056] [ 0] /lib/x86_64-linux-gnu/libpthread.so.0(+0x13140)[0x7f72879d9140]
[n147-181-220:77056] [ 1] /lib/x86_64-linux-gnu/libc.so.6(gsignal+0x141)[0x7f72876d9ce1]
[n147-181-220:77056] [ 2] /lib/x86_64-linux-gnu/libc.so.6(abort+0x123)[0x7f72876c3537]
[n147-181-220:77056] [ 3] /opt/tiger/miniconda3/bin/../lib/libstdc++.so.6(_ZN9__gnu_cxx27__verbose_terminate_handlerEv+0xc0)[0x7f72773fdf00]
[n147-181-220:77056] [ 4] /opt/tiger/miniconda3/bin/../lib/libstdc++.so.6(+0xb643c)[0x7f72773fc43c]
[n147-181-220:77056] [ 5] /opt/tiger/miniconda3/bin/../lib/libstdc++.so.6(+0xb57ff)[0x7f72773fb7ff]
[n147-181-220:77056] [ 6] /opt/tiger/miniconda3/bin/../lib/libstdc++.so.6(__gxx_personality_v0+0x356)[0x7f72773fc07f]
[n147-181-220:77056] [ 7] /opt/tiger/miniconda3/lib/python3.10/site-packages/numpy/core/../../../../libgcc_s.so.1(+0x12743)[0x7f7286ea2743]
[n147-181-220:77056] [ 8] /opt/tiger/miniconda3/lib/python3.10/site-packages/numpy/core/../../../../libgcc_s.so.1(_Unwind_Resume+0x65)[0x7f7286ea2d04]
[n147-181-220:77056] [ 9] /opt/tiger/miniconda3/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so(+0x1c5edac)[0x7f707a307dac]
[n147-181-220:77056] [10] /opt/tiger/miniconda3/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so(+0x1c74eb8)[0x7f707a31deb8]
[n147-181-220:77056] [11] /opt/tiger/miniconda3/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so(_ZN12tensorrt_llm7plugins27WeightOnlyQuantMatmulPlugin7enqueueEPKN8nvinfer116PluginTensorDescES5_PKPKvPKPvSA_P11CUstream_st+0x132)[0x7f7078815472]
[n147-181-220:77056] [12] /opt/tiger/miniconda3/lib/python3.10/site-packages/tensorrt_libs/libnvinfer.so.9(+0x10e45e9)[0x7f716a1ee5e9]
[n147-181-220:77056] [13] /opt/tiger/miniconda3/lib/python3.10/site-packages/tensorrt_libs/libnvinfer.so.9(+0x10a86af)[0x7f716a1b26af]
[n147-181-220:77056] [14] /opt/tiger/miniconda3/lib/python3.10/site-packages/tensorrt_libs/libnvinfer.so.9(+0x10aa320)[0x7f716a1b4320]
[n147-181-220:77056] [15] /opt/tiger/miniconda3/lib/python3.10/site-packages/tensorrt_llm/libs/libtensorrt_llm.so(_ZN12tensorrt_llm7runtime10GptSession18executeContextStepERKSt6vectorINS0_15GenerationInputESaIS3_EERKS2_IiSaIiEEPKNS_13batch_manager16kv_cache_manager14KVCacheManagerE+0x377)[0x7f71039b2ae7]
[n147-181-220:77056] [16] /opt/tiger/miniconda3/lib/python3.10/site-packages/tensorrt_llm/libs/libtensorrt_llm.so(_ZN12tensorrt_llm7runtime10GptSession15generateBatchedERSt6vectorINS0_16GenerationOutputESaIS3_EERKS2_INS0_15GenerationInputESaIS7_EERKNS0_14SamplingConfigERKSt8functionIFvibEESt10shared_ptrINS1_18GenerationProfilerEE+0xc9e)[0x7f71039b3dfe]
[n147-181-220:77056] [17] /opt/tiger/miniconda3/lib/python3.10/site-packages/tensorrt_llm/libs/libtensorrt_llm.so(_ZN12tensorrt_llm7runtime10GptSession8generateERNS0_16GenerationOutputERKNS0_15GenerationInputERKNS0_14SamplingConfigESt10shared_ptrINS1_18GenerationProfilerEE+0x84c)[0x7f71039b544c]
[n147-181-220:77056] [18] /opt/tiger/miniconda3/lib/python3.10/site-packages/tensorrt_llm/bindings.cpython-310-x86_64-linux-gnu.so(+0x64712)[0x7f71072e3712]
[n147-181-220:77056] [19] /opt/tiger/miniconda3/lib/python3.10/site-packages/tensorrt_llm/bindings.cpython-310-x86_64-linux-gnu.so(+0x4b5c7)[0x7f71072ca5c7]
[n147-181-220:77056] [20] python3(+0x1445a6)[0x55b65c3db5a6]
[n147-181-220:77056] [21] python3(_PyObject_MakeTpCall+0x26b)[0x55b65c3d4a6b]
[n147-181-220:77056] [22] python3(+0x150866)[0x55b65c3e7866]
[n147-181-220:77056] [23] python3(_PyEval_EvalFrameDefault+0x4c12)[0x55b65c3d0142]
[n147-181-220:77056] [24] python3(+0x150582)[0x55b65c3e7582]
[n147-181-220:77056] [25] python3(PyObject_Call+0xbc)[0x55b65c3e7f1c]
[n147-181-220:77056] [26] python3(_PyEval_EvalFrameDefault+0x2d83)[0x55b65c3ce2b3]
[n147-181-220:77056] [27] python3(_PyFunction_Vectorcall+0x6c)[0x55b65c3dba2c]
[n147-181-220:77056] [28] python3(_PyEval_EvalFrameDefault+0x320)[0x55b65c3cb850]
[n147-181-220:77056] [29] python3(+0x1d7c60)[0x55b65c46ec60]
[n147-181-220:77056] *** End of error message ***
Aborted (core dumped)

additional notes

The support matrix lists w8 (INT8 weight-only) as supported on the Volta architecture. https://github.com/NVIDIA/TensorRT-LLM/issues/1155 is a similar issue, but it is not the same problem.

YihengBrianWu, May 22 '24 12:05

@YihengBrianWu could you please provide more details about the model you used? This can help on reproducing this issue, thanks.

Barry-Delaney, May 23 '24 08:05

Hi, the model I used is Yi-6B with SFT. I also built an FP16 engine, and the FP16 version runs inference without error.

YihengBrianWu, May 23 '24 08:05
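The KV cache figures in the log ("Max tokens in paged KV cache: 364288. Allocating 23873978368 bytes.") can be sanity-checked against the model shape. The sketch below assumes the commonly published Yi-6B attention configuration (32 layers, 4 key/value heads, head dimension 128) and an FP16 KV cache; none of these values are stated in the thread, so treat them as assumptions. Under them, the per-token size matches the log, which suggests the cache was sized as expected and the crash lies in the weight-only GEMM rather than in memory allocation.

```python
# Assumed Yi-6B attention configuration (not stated in the thread).
NUM_LAYERS = 32
NUM_KV_HEADS = 4
HEAD_DIM = 128
BYTES_PER_ELEMENT = 2  # assumes a float16 KV cache

# Values copied from the log in the issue.
MAX_TOKENS = 364288
ALLOCATED_BYTES = 23873978368


def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_element):
    """Bytes needed to store the keys and values of one token across all layers."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_element


per_token = kv_bytes_per_token(NUM_LAYERS, NUM_KV_HEADS, HEAD_DIM, BYTES_PER_ELEMENT)
predicted_total = per_token * MAX_TOKENS

print("bytes per token (model shape):", per_token)
print("bytes per token (from log):   ", ALLOCATED_BYTES // MAX_TOKENS)
print("predicted matches log:        ", predicted_total == ALLOCATED_BYTES)
```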

I tested the same model on a T4. FP16 works fine. However, when I changed the precision to w8, the engine build step failed. The script is exactly the same as in the V100 case.

The error message is: void cutlass::gemm::kernel::GemmFpAIntB<Mma_, Epilogue_, ThreadblockSwizzle_, KernelArch, SplitKSerial>::run_kernel(const cutlass::gemm::kernel::GemmFpAIntB<Mma_, Epilogue_, ThreadblockSwizzle_, KernelArch, SplitKSerial>::Params &, cutlass::gemm::kernel::GemmFpAIntB<Mma_, Epilogue_, ThreadblockSwizzle_, KernelArch, SplitKSerial>::SharedStorage &) [with CompilationArch = cutlass::arch::Sm70; Mma_ = cutlass::gemm::threadblock::DqMmaPipelined<cutlass::gemm::GemmShape<16, 128, 64>, cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<16, 64>, cutlass::half_t, cutlass::layout::RowMajor, 1, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<64, 16>, 128, cutlass::PitchLinearShape<8, 4>, 8>, 8, false, cutlass::layout::NoPermute>, cutlass::transform::threadblock::RegularTileIterator<cutlass::MatrixShape<16, 64>, cutlass::half_t, cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<16, 64>, 0, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<64, 16>, 128, cutlass::PitchLinearShape<8, 4>, 8>, 16>, cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<128, 64>, unsigned char, cutlass::layout::ColumnMajor, 0, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<128, 64>, 128, cutlass::PitchLinearShape<8, 4>, 16>, 16, false, cutlass::layout::NoPermute>, cutlass::transform::threadblock::RegularTileIterator<cutlass::MatrixShape<64, 128>, unsigned char, cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<8, 64>, 1, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<64, 128>, 128, cutlass::PitchLinearShape<4, 8>, 16>, 16>, cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, 128>, cutlass::half_t, cutlass::layout::RowMajor, 0, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<128, 1>, 16, 8>, 8, false, cutlass::layout::NoPermute>, 
cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, 128>, cutlass::half_t, cutlass::layout::RowMajor, 0, cutlass::transform::PitchLinearStripminedThreadMap<cutlass::PitchLinearShape<128, 1>, 16, 8>, 8, false, cutlass::layout::NoPermute>, float, cutlass::layout::RowMajor, cutlass::gemm::threadblock::MmaPolicy<cutlass::gemm::warp::MmaTensorOpComputeBWithF16<cutlass::gemm::GemmShape<16, 32, 64>, cutlass::half_t, cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<16, 64>, unsigned char, cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<8, 64>, float, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 8, 8>, 32, cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, float, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>, cutlass::MatrixShape<1, 1>>, cutlass::gemm::GemmShape<16, 8, 16>, 1, false, __nv_bool>, cutlass::MatrixShape<0, 0>, cutlass::MatrixShape<0, 0>, 1>, cutlass::NumericArrayConverter<unsigned char, unsigned char, 64, cutlass::FloatRoundStyle::round_to_nearest, cutlass::transform::thread::UnaryTransform::Identity>, cutlass::FastInterleavedAndBiasedNumericArrayConverter<cutlass::half_t, unsigned char, 16>, cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, _nv_bool>; Epilogue = cutlass::epilogue::threadblock::Epilogue<cutlass::gemm::GemmShape<16, 128, 64>, cutlass::gemm::warp::MmaTensorOpComputeBWithF16<cutlass::gemm::GemmShape<16, 32, 64>, cutlass::half_t, cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<16, 64>, unsigned char, cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<8, 64>, float, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 8, 8>, 32, cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, cutlass::layout::ColumnMajor, float, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>, cutlass::MatrixShape<1, 1>>, 
cutlass::gemm::GemmShape<16, 8, 16>, 1, false, _nv_bool>, 1, cutlass::epilogue::threadblock::PredicatedTileIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 8, 1, 1, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 2, 1, 1, 2>, 128, 8, 16>, cutlass::half_t, false, cutlass::layout::NoPermute, false>, cutlass::epilogue::warp::FragmentIteratorTensorOp<cutlass::gemm::GemmShape<16, 32, 64>, cutlass::gemm::GemmShape<16, 8, 8>, float, cutlass::Array<float, 4, true>, cutlass::layout::RowMajor>, cutlass::epilogue::warp::TileIteratorTensorOpMixed<cutlass::gemm::GemmShape<16, 32, 64>, cutlass::gemm::GemmShape<16, 8, 8>, float, 32, 16, 8, 8, false>, cutlass::epilogue::threadblock::SharedLoadIteratorMixed<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 8, 1, 1, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 2, 1, 1, 2>, 128, 8, 16>::CompactedThreadMap, float, 32, 16, 8, 8, false>, cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float, cutlass::epilogue::thread::ScaleType::NoBetaScaling, cutlass::FloatRoundStyle::round_to_nearest, cutlass::half_t>, cutlass::MatrixShape<0, 8>, 2, 1>; ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; KernelArch = cutlass::arch::Sm75; __nv_bool SplitKSerial = true] not implemented

YihengBrianWu, May 24 '24 08:05
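The T4 error message is long, but the two fields that matter most are near its start and end: `CompilationArch` and `KernelArch`. A small parser can pull them out. This is a sketch; the function name `extract_arch_fields` is hypothetical, and the excerpt is a shortened copy of the message above with the template arguments elided.

```python
import re


def extract_arch_fields(message):
    """Return the CompilationArch and KernelArch values found in a CUTLASS error message."""
    fields = {}
    for name in ("CompilationArch", "KernelArch"):
        # Only "name = cutlass::arch::SmNN" assignments match; bare template
        # parameter mentions such as "<..., KernelArch, ...>" are skipped.
        match = re.search(rf"\b{name}\s*=\s*cutlass::arch::(Sm\d+)", message)
        fields[name] = match.group(1) if match else None
    return fields


# Shortened excerpt of the message posted in the thread.
excerpt = (
    "[with CompilationArch = cutlass::arch::Sm70; Mma_ = ...; "
    "KernelArch = cutlass::arch::Sm75; __nv_bool SplitKSerial = true] not implemented"
)

fields = extract_arch_fields(excerpt)
print(fields)
print("architectures differ:", fields["CompilationArch"] != fields["KernelArch"])
```

In the posted message the two values differ (Sm70 versus Sm75, the latter being the T4's architecture). One plausible reading, not confirmed in the thread, is that the installed binary lacks a kernel compiled for SM75, which would be consistent with the maintainer's later suggestion to try a clean build.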

Thanks for the information. Let me try to reproduce and investigate this.

Barry-Delaney, May 24 '24 08:05

Hi @YihengBrianWu, could you please try the tests on the latest main branch? I ran them successfully on V100s. Thanks!

Barry-Delaney, May 28 '24 14:05

Thanks! I also tested the latest code on V100, and it works with w8/w4 precision! However, T4 still does not work.

v0.11.0 seems to require pynvml==0.15.0, but when I convert the model on T4 with pynvml==0.15.0, it raises an error: pynvml.nvml.NVMLError_FunctionNotFound: Function Not Found

When I downgrade pynvml to 0.14.0, the convert step raises the same error I reported a few days ago. Can you help check this?

YihengBrianWu, May 29 '24 08:05

I tried the latest main with pynvml == 11.4.0/11.4.1/11.5.0, and all of these versions work fine for the conversion on T4. Could you please try a clean build and, if the error persists, share more details about the reproduction steps?

Barry-Delaney, May 30 '24 08:05
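Since the two sides quote different pynvml version numbers, it helps to confirm what is actually installed before retrying. The sketch below reads the installed version with the standard library and compares it against the versions the maintainer reported as working on T4. The helper names are hypothetical, and the check says nothing about versions outside that list.

```python
from importlib import metadata

# Versions reported in this thread as working for conversion on T4.
REPORTED_WORKING = {"11.4.0", "11.4.1", "11.5.0"}


def installed_version(package):
    """Return the installed version string of a package, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def describe(version, reported_working):
    """Classify a version string relative to the versions reported as working."""
    if version is None:
        return "not installed"
    if version in reported_working:
        return "reported working"
    return "untested in this thread"


if __name__ == "__main__":
    version = installed_version("pynvml")
    print("pynvml:", version, "->", describe(version, REPORTED_WORKING))
```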

Hi @YihengBrianWu, could we close this ticket now?

nv-guomingz, Jun 06 '24 13:06