LTX-Video icon indicating copy to clipboard operation
LTX-Video copied to clipboard

[BUG] FP8Linear.forward() argument mismatch in LTX-Video inference

Open ighoshsubho opened this issue 6 months ago • 4 comments

Encountering a TypeError during LTX-Video inference: FP8Linear.forward() receives 5 positional arguments (counting self) but accepts only 2–4. This appears to be related to the q8_kernels integration with the diffusers pipeline.

Error Details

TypeError: FP8Linear.forward() takes from 2 to 4 positional arguments but 5 were given

Full Stack Trace

/workspace/LTX-Video/env/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Padded dimensions: 960x1280x81
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 55.41it/s]
/workspace/LTX-Video/env/lib/python3.11/site-packages/torch/functional.py:554: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /pytorch/aten/src/ATen/native/TensorShape.cpp:4314.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  0%|          | 0/7 [00:00<?, ?it/s]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 8
      5 WIDTH = 1280
      6 NUM_FRAMES = 81
----> 8 infer(
      9     InferenceConfig(
     10         pipeline_config="configs/ltxv-13b-0.9.8-distilled-fp8.yaml",
     11         prompt=PROMPT,
     12         height=HEIGHT,
     13         width=WIDTH,
     14         num_frames=NUM_FRAMES,
     15         output_path="/workspace/output.mp4",
     16     )
     17 )

File /workspace/LTX-Video/ltx_video/inference.py:569, in infer(config)
    560 sample = {
    561     "prompt": config.prompt,
    562     "prompt_attention_mask": None,
    563     "negative_prompt": config.negative_prompt,
    564     "negative_prompt_attention_mask": None,
    565 }
    567 generator = torch.Generator(device=device).manual_seed(config.seed)
--> 569 images = pipeline(
    570     **pipeline_config,
    571     skip_layer_strategy=skip_layer_strategy,
    572     generator=generator,
    573     output_type="pt",
    574     callback_on_step_end=None,
    575     height=height_padded,
    576     width=width_padded,
    577     num_frames=num_frames_padded,
    578     frame_rate=config.frame_rate,
    579     **sample,
    580     media_items=media_item,
    581     conditioning_items=conditioning_items,
    582     is_video=True,
    583     vae_per_channel_normalize=True,
    584     image_cond_noise_scale=config.image_cond_noise_scale,
    585     mixed_precision=(precision == "mixed_precision"),
    586     offload_to_cpu=offload_to_cpu,
    587     device=device,
    588     enhance_prompt=enhance_prompt,
    589 ).images

File /workspace/LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py:1865, in LTXMultiScalePipeline.__call__(self, downscale_factor, first_pass, second_pass, *args, **kwargs)
   1863 kwargs["height"] = downscaled_height
   1864 kwargs.update(**first_pass)
-> 1865 result = self.video_pipeline(*args, **kwargs)
   1866 latents = result.images
   1868 upsampled_latents = self._upsample_latents(self.latent_upsampler, latents)

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File /workspace/LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py:1206, in LTXVideoPipeline.__call__(self, height, width, num_frames, frame_rate, prompt, negative_prompt, num_inference_steps, skip_initial_inference_steps, skip_final_inference_steps, timesteps, guidance_scale, cfg_star_rescale, skip_layer_strategy, skip_block_list, stg_scale, rescaling_scale, guidance_timesteps, num_images_per_prompt, eta, generator, latents, prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, output_type, return_dict, callback_on_step_end, conditioning_items, decode_timestep, decode_noise_scale, mixed_precision, offload_to_cpu, enhance_prompt, text_encoder_max_tokens, stochastic_sampling, media_items, tone_map_compression_ratio, **kwargs)
   1204 # predict noise model_output
   1205 with context_manager:
-> 1206     noise_pred = self.transformer(
   1207         latent_model_input.to(self.transformer.dtype),
   1208         indices_grid=fractional_coords,
   1209         encoder_hidden_states=prompt_embeds_batch[indices].to(
   1210             self.transformer.dtype
   1211         ),
   1212         encoder_attention_mask=prompt_attention_mask_batch[indices],
   1213         timestep=current_timestep,
   1214         skip_layer_mask=skip_layer_mask,
   1215         skip_layer_strategy=skip_layer_strategy,
   1216         return_dict=False,
   1217     )[0]
   1219 # perform guidance
   1220 if do_spatio_temporal_guidance:

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)

File /workspace/LTX-Video/ltx_video/models/transformers/transformer3d.py:478, in Transformer3DModel.forward(self, hidden_states, indices_grid, encoder_hidden_states, timestep, class_labels, cross_attention_kwargs, attention_mask, encoder_attention_mask, skip_layer_mask, skip_layer_strategy, return_dict)
    459         hidden_states = torch.utils.checkpoint.checkpoint(
    460             create_custom_forward(block),
    461             hidden_states,
   (...)    475             **ckpt_kwargs,
    476         )
    477     else:
--> 478         hidden_states = block(
    479             hidden_states,
    480             freqs_cis=freqs_cis,
    481             attention_mask=attention_mask,
    482             encoder_hidden_states=encoder_hidden_states,
    483             encoder_attention_mask=encoder_attention_mask,
    484             timestep=timestep,
    485             cross_attention_kwargs=cross_attention_kwargs,
    486             class_labels=class_labels,
    487             skip_layer_mask=(
    488                 skip_layer_mask[block_idx]
    489                 if skip_layer_mask is not None
    490                 else None
    491             ),
    492             skip_layer_strategy=skip_layer_strategy,
    493         )

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)

File /workspace/LTX-Video/env/lib/python3.11/site-packages/q8_kernels/integration/diffusers.py:380, in create_forwards.<locals>.fused_forward(self, hidden_states, freqs_cis, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels, sharding_mesh, skip_layer_mask, skip_layer_strategy)
    375 # 1. Prepare GLIGEN inputs
    376 cross_attention_kwargs = (
    377     cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
    378 )
--> 380 attn_output = self.attn1(
    381     norm_hidden_states,
    382     norm_hidden_states_scales,
    383     freqs_cis=freqs_cis,
    384     encoder_hidden_states=(
    385         encoder_hidden_states if self.only_cross_attention else None
    386     ),
    387     attention_mask=attention_mask,
    388     skip_layer_mask=skip_layer_mask,
    389     skip_layer_strategy=skip_layer_strategy,
    390     **cross_attention_kwargs,
    391 )
    392 if gate_msa is not None:
    393     attn_output = gate_msa * attn_output

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)

File /workspace/LTX-Video/env/lib/python3.11/site-packages/q8_kernels/integration/diffusers.py:67, in attn_forward(self, hidden_states, hidden_states_scales, freqs_cis, encoder_hidden_states, attention_mask, skip_layer_mask, skip_layer_strategy, **cross_attention_kwargs)
     59     logger.warning(
     60         f"cross_attention_kwargs {unused_kwargs} are not expected by"
     61         f" {self.processor.__class__.__name__} and will be ignored."
     62     )
     63 cross_attention_kwargs = {
     64     k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters
     65 }
---> 67 return self.processor(
     68     self,
     69     hidden_states,
     70     hidden_states_scales,
     71     freqs_cis=freqs_cis,
     72     encoder_hidden_states=encoder_hidden_states,
     73     attention_mask=attention_mask,
     74     skip_layer_mask=skip_layer_mask,
     75     skip_layer_strategy=skip_layer_strategy,
     76     **cross_attention_kwargs,
     77 )

File /workspace/LTX-Video/env/lib/python3.11/site-packages/q8_kernels/integration/diffusers.py:158, in create_attn_processor.<locals>.AttnProcessor3_0.__call__(self, attn, hidden_states, hidden_states_scales, freqs_cis, encoder_hidden_states, attention_mask, temb, skip_layer_mask, skip_layer_strategy, *args, **kwargs)
    155 else:  # if no context provided do self-attention
    156     is_self_attention = True
--> 158     query = attn.to_q(
    159         hidden_states, hidden_states_scales, False, torch.bfloat16
    160     )
    161     query = rms_norm_rope(
    162         query, freqs_cis[0], freqs_cis[1], attn.q_norm.weight
    163     )
    165     key = attn.to_k(
    166         hidden_states, hidden_states_scales, False, torch.bfloat16
    167     )

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)

TypeError: FP8Linear.forward() takes from 2 to 4 positional arguments but 5 were given

Reproduction Steps

  1. Set up LTX-Video environment with FP8 distilled model
  2. Configure inference with the following parameters:
    PROMPT = [your_prompt_here]
    HEIGHT = 960
    WIDTH = 1280
    NUM_FRAMES = 81
    
    infer(
        InferenceConfig(
            pipeline_config="configs/ltxv-13b-0.9.8-distilled-fp8.yaml",
            prompt=PROMPT,
            height=HEIGHT,
            width=WIDTH,
            num_frames=NUM_FRAMES,
            output_path="/workspace/output.mp4",
        )
    )
    
  3. Execute inference - error occurs during transformer forward pass

Root Cause Analysis

The error occurs in the q8_kernels integration, where attn.to_q() is called with 4 explicit arguments (5 counting the implicit self):

query = attn.to_q(
    hidden_states, hidden_states_scales, False, torch.bfloat16
)

However, FP8Linear.forward() accepts only 2–4 positional arguments (counting self), that is, at most 3 besides self. This suggests a signature mismatch between the interface the q8_kernels diffusers integration expects and the actual FP8Linear implementation.

Environment Information

  • Model Config: configs/ltxv-13b-0.9.8-distilled-fp8.yaml
  • Python: 3.11
  • PyTorch: 2.7.1+cu126
  • q8_kernels: q8_kernels @ file:///workspace/LTX-Video/LTX-Video-Q8-Kernels/dist/q8_kernels-0.0.5-cp311-cp311-linux_x86_64.whl#sha256=54e6ab58fa9acccfd60316e876e0807ced33df91929f59415d9b57a893f0acb7
  • Platform: Linux workspace environment

ighoshsubho avatar Jul 25 '25 09:07 ighoshsubho

Same error here. Does this mean I cannot use the FP8 model?

jinyixin621 avatar Jul 29 '25 20:07 jinyixin621

same error

georgeyfly avatar Sep 29 '25 17:09 georgeyfly

I'm getting the same error.

I've tried debugging the issue a bit. It seems that the following call sites in the q8_kernels integration with the diffusers pipeline (file .../q8_kernels/integration/diffusers.py) call FP8Linear.forward(...): line 149 (attn.to_q), line 152 (attn.to_k), line 154 (attn.to_v), line 158 (attn.to_q), line 165 (attn.to_k), line 170 (attn.to_v), line 284 (attn.to_out), line 452 (self.proj), and line 469 (self.net[2]). In all of these places a boolean argument (True or False, depending on the call) is passed to FP8Linear.forward(...). The latter doesn't accept any boolean parameter, so I assume the problem lies either in how the callback function is assigned or in the invocation itself. There isn't much information in the repository to go on, so I couldn't trace these changes or work out what that boolean parameter means. Maybe someone changed use_hadamard from an FP8Linear.forward(...) parameter into a class property of FP8Linear and forgot to update the call sites?

For now, I have applied a band-aid in my Python environment: I changed all the invocations above so they no longer pass the boolean argument. Inference works after that, but it is most likely not the proper fix.

P.S. It may be better to move this issue to the LTX-Video-Q8-Kernels repository, as this problem is specific to its code.

Marat-Khaitov avatar Oct 02 '25 14:10 Marat-Khaitov

+1

maty-bohacek avatar Oct 04 '25 18:10 maty-bohacek

I've been fighting this today. I naively tried removing the False argument from the calls in diffusers.py, changing them to:

                query = attn.to_q(hidden_states, hidden_states_scales, torch.bfloat16)
                query = rms_norm_rope(query, freqs_cis[0], freqs_cis[1], attn.q_norm.weight)
                key = attn.to_k(hidden_states, hidden_states_scales, torch.bfloat16)
                key = rms_norm_rope(key, freqs_cis[0], freqs_cis[1], attn.k_norm.weight)
                value = attn.to_v(hidden_states, hidden_states_scales, torch.bfloat16)

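That bypassed the original error, but things then failed later in a worse way (call stack below). The device-side assertion printed first (from mma_sm89_fp16.hpp) reports an attempt to use an SM89 FP8 MMA instruction without CUTE_ARCH_MMA_F16_SM89_ENABLED. That suggests a problem with how q8_kernels was built for this GPU/toolchain, possibly separate from the argument mismatch. I'm just going to have to forgo using the FP8 path.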

/home/neil/repos/LTX-Video-Q8-Kernels/csrc/gemm/mma_sm89_fp16.hpp:80: static void cute::SM89_16x8x32_F16E4M3E4M3F16_TN::fma(unsigned int &, unsigned int &, const unsigned int &, const unsigned int &, const unsigned int &, const unsigned int &, const unsigned int &, const unsigned int &, const unsigned int &, const unsigned int &): block: [11,0,0], thread: [127,0,0] Assertion `0 && "Attempting to use SM89_16x8x32_F32E4M3E4M3F32_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED"` failed.
  0%|                                                                                                             | 0/7 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/usr/lib/python3.12/runpy.py", line 198, in _run_module_as_main
    return _run_code(code, main_globals, None,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/runpy.py", line 88, in _run_code
    exec(code, run_globals)
  File "/home/neil/.vscode-server/extensions/ms-python.debugpy-2025.16.0/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 71, in <module>
    cli.main()
  File "/home/neil/.vscode-server/extensions/ms-python.debugpy-2025.16.0/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 508, in main
    run()
  File "/home/neil/.vscode-server/extensions/ms-python.debugpy-2025.16.0/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 358, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/neil/.vscode-server/extensions/ms-python.debugpy-2025.16.0/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 310, in run_path
    return _run_module_code(code, init_globals, run_name, pkg_name=pkg_name, script_name=fname)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/.vscode-server/extensions/ms-python.debugpy-2025.16.0/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 127, in _run_module_code
    _run_code(code, mod_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
  File "/home/neil/.vscode-server/extensions/ms-python.debugpy-2025.16.0/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 118, in _run_code
    exec(code, run_globals)
  File "/home/neil/repos/LTX-Video/inference.py", line 13, in <module>
    main()
  File "/home/neil/repos/LTX-Video/inference.py", line 9, in main
    infer(config=config)
  File "/home/neil/repos/LTX-Video/ltx_video/inference.py", line 569, in infer
    images = pipeline(
             ^^^^^^^^^
  File "/home/neil/repos/LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py", line 1865, in __call__
    result = self.video_pipeline(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py", line 1206, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/ltx_video/models/transformers/transformer3d.py", line 478, in forward
    hidden_states = block(
                    ^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/q8_kernels/integration/diffusers.py", line 370, in fused_forward
    attn_output = self.attn1(
                  ^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/q8_kernels/integration/diffusers.py", line 67, in attn_forward
    return self.processor(
           ^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/q8_kernels/integration/diffusers.py", line 159, in __call__
    query = rms_norm_rope(query, freqs_cis[0], freqs_cis[1], attn.q_norm.weight)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/q8_kernels/functional/ops.py", line 256, in rms_norm_rope
    return RMSNormRope.apply(x, weights, cos_freqs, sin_freqs, out_16bit)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/q8_kernels/functional/ops.py", line 207, in forward
    return torch.ops.q8_kernels_ops.rms_norm_rope(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/_ops.py", line 1116, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/_library/autograd.py", line 113, in autograd_impl
    result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/_library/autograd.py", line 40, in forward_no_grad
    result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/_ops.py", line 721, in redispatch
    return self._handle.redispatch_boxed(keyset, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 324, in backend_impl
    result = self._backend_fns[device_type](*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 367, in wrapped_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/q8_kernels/functional/ops.py", line 26, in _rms_norm_rope_cuda
    return rms_norm_rope_cuda(x, weights, cos_freqs, sin_freqs, out_16bit)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

neilobremski avatar Dec 12 '25 20:12 neilobremski