[BUG] FP8Linear.forward() argument mismatch in LTX-Video inference
I'm encountering a TypeError during LTX-Video inference: FP8Linear.forward() receives 5 positional arguments but accepts only 2-4. This appears to be related to the q8_kernels integration with the diffusers pipeline.
Error Details
TypeError: FP8Linear.forward() takes from 2 to 4 positional arguments but 5 were given
Full Stack Trace
/workspace/LTX-Video/env/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Padded dimensions: 960x1280x81
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 55.41it/s]
/workspace/LTX-Video/env/lib/python3.11/site-packages/torch/functional.py:554: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /pytorch/aten/src/ATen/native/TensorShape.cpp:4314.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
0%| | 0/7 [00:00<?, ?it/s]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[1], line 8
5 WIDTH = 1280
6 NUM_FRAMES = 81
----> 8 infer(
9 InferenceConfig(
10 pipeline_config="configs/ltxv-13b-0.9.8-distilled-fp8.yaml",
11 prompt=PROMPT,
12 height=HEIGHT,
13 width=WIDTH,
14 num_frames=NUM_FRAMES,
15 output_path="/workspace/output.mp4",
16 )
17 )
File /workspace/LTX-Video/ltx_video/inference.py:569, in infer(config)
560 sample = {
561 "prompt": config.prompt,
562 "prompt_attention_mask": None,
563 "negative_prompt": config.negative_prompt,
564 "negative_prompt_attention_mask": None,
565 }
567 generator = torch.Generator(device=device).manual_seed(config.seed)
--> 569 images = pipeline(
570 **pipeline_config,
571 skip_layer_strategy=skip_layer_strategy,
572 generator=generator,
573 output_type="pt",
574 callback_on_step_end=None,
575 height=height_padded,
576 width=width_padded,
577 num_frames=num_frames_padded,
578 frame_rate=config.frame_rate,
579 **sample,
580 media_items=media_item,
581 conditioning_items=conditioning_items,
582 is_video=True,
583 vae_per_channel_normalize=True,
584 image_cond_noise_scale=config.image_cond_noise_scale,
585 mixed_precision=(precision == "mixed_precision"),
586 offload_to_cpu=offload_to_cpu,
587 device=device,
588 enhance_prompt=enhance_prompt,
589 ).images
File /workspace/LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py:1865, in LTXMultiScalePipeline.__call__(self, downscale_factor, first_pass, second_pass, *args, **kwargs)
1863 kwargs["height"] = downscaled_height
1864 kwargs.update(**first_pass)
-> 1865 result = self.video_pipeline(*args, **kwargs)
1866 latents = result.images
1868 upsampled_latents = self._upsample_latents(self.latent_upsampler, latents)
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File /workspace/LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py:1206, in LTXVideoPipeline.__call__(self, height, width, num_frames, frame_rate, prompt, negative_prompt, num_inference_steps, skip_initial_inference_steps, skip_final_inference_steps, timesteps, guidance_scale, cfg_star_rescale, skip_layer_strategy, skip_block_list, stg_scale, rescaling_scale, guidance_timesteps, num_images_per_prompt, eta, generator, latents, prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, output_type, return_dict, callback_on_step_end, conditioning_items, decode_timestep, decode_noise_scale, mixed_precision, offload_to_cpu, enhance_prompt, text_encoder_max_tokens, stochastic_sampling, media_items, tone_map_compression_ratio, **kwargs)
1204 # predict noise model_output
1205 with context_manager:
-> 1206 noise_pred = self.transformer(
1207 latent_model_input.to(self.transformer.dtype),
1208 indices_grid=fractional_coords,
1209 encoder_hidden_states=prompt_embeds_batch[indices].to(
1210 self.transformer.dtype
1211 ),
1212 encoder_attention_mask=prompt_attention_mask_batch[indices],
1213 timestep=current_timestep,
1214 skip_layer_mask=skip_layer_mask,
1215 skip_layer_strategy=skip_layer_strategy,
1216 return_dict=False,
1217 )[0]
1219 # perform guidance
1220 if do_spatio_temporal_guidance:
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
File /workspace/LTX-Video/ltx_video/models/transformers/transformer3d.py:478, in Transformer3DModel.forward(self, hidden_states, indices_grid, encoder_hidden_states, timestep, class_labels, cross_attention_kwargs, attention_mask, encoder_attention_mask, skip_layer_mask, skip_layer_strategy, return_dict)
459 hidden_states = torch.utils.checkpoint.checkpoint(
460 create_custom_forward(block),
461 hidden_states,
(...) 475 **ckpt_kwargs,
476 )
477 else:
--> 478 hidden_states = block(
479 hidden_states,
480 freqs_cis=freqs_cis,
481 attention_mask=attention_mask,
482 encoder_hidden_states=encoder_hidden_states,
483 encoder_attention_mask=encoder_attention_mask,
484 timestep=timestep,
485 cross_attention_kwargs=cross_attention_kwargs,
486 class_labels=class_labels,
487 skip_layer_mask=(
488 skip_layer_mask[block_idx]
489 if skip_layer_mask is not None
490 else None
491 ),
492 skip_layer_strategy=skip_layer_strategy,
493 )
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
File /workspace/LTX-Video/env/lib/python3.11/site-packages/q8_kernels/integration/diffusers.py:380, in create_forwards.<locals>.fused_forward(self, hidden_states, freqs_cis, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels, sharding_mesh, skip_layer_mask, skip_layer_strategy)
375 # 1. Prepare GLIGEN inputs
376 cross_attention_kwargs = (
377 cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
378 )
--> 380 attn_output = self.attn1(
381 norm_hidden_states,
382 norm_hidden_states_scales,
383 freqs_cis=freqs_cis,
384 encoder_hidden_states=(
385 encoder_hidden_states if self.only_cross_attention else None
386 ),
387 attention_mask=attention_mask,
388 skip_layer_mask=skip_layer_mask,
389 skip_layer_strategy=skip_layer_strategy,
390 **cross_attention_kwargs,
391 )
392 if gate_msa is not None:
393 attn_output = gate_msa * attn_output
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
File /workspace/LTX-Video/env/lib/python3.11/site-packages/q8_kernels/integration/diffusers.py:67, in attn_forward(self, hidden_states, hidden_states_scales, freqs_cis, encoder_hidden_states, attention_mask, skip_layer_mask, skip_layer_strategy, **cross_attention_kwargs)
59 logger.warning(
60 f"cross_attention_kwargs {unused_kwargs} are not expected by"
61 f" {self.processor.__class__.__name__} and will be ignored."
62 )
63 cross_attention_kwargs = {
64 k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters
65 }
---> 67 return self.processor(
68 self,
69 hidden_states,
70 hidden_states_scales,
71 freqs_cis=freqs_cis,
72 encoder_hidden_states=encoder_hidden_states,
73 attention_mask=attention_mask,
74 skip_layer_mask=skip_layer_mask,
75 skip_layer_strategy=skip_layer_strategy,
76 **cross_attention_kwargs,
77 )
File /workspace/LTX-Video/env/lib/python3.11/site-packages/q8_kernels/integration/diffusers.py:158, in create_attn_processor.<locals>.AttnProcessor3_0.__call__(self, attn, hidden_states, hidden_states_scales, freqs_cis, encoder_hidden_states, attention_mask, temb, skip_layer_mask, skip_layer_strategy, *args, **kwargs)
155 else: # if no context provided do self-attention
156 is_self_attention = True
--> 158 query = attn.to_q(
159 hidden_states, hidden_states_scales, False, torch.bfloat16
160 )
161 query = rms_norm_rope(
162 query, freqs_cis[0], freqs_cis[1], attn.q_norm.weight
163 )
165 key = attn.to_k(
166 hidden_states, hidden_states_scales, False, torch.bfloat16
167 )
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File /workspace/LTX-Video/env/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
TypeError: FP8Linear.forward() takes from 2 to 4 positional arguments but 5 were given
Reproduction Steps
- Set up LTX-Video environment with FP8 distilled model
- Configure inference with the following parameters:
```python
PROMPT = [your_prompt_here]
HEIGHT = 960
WIDTH = 1280
NUM_FRAMES = 81

infer(
    InferenceConfig(
        pipeline_config="configs/ltxv-13b-0.9.8-distilled-fp8.yaml",
        prompt=PROMPT,
        height=HEIGHT,
        width=WIDTH,
        num_frames=NUM_FRAMES,
        output_path="/workspace/output.mp4",
    )
)
```
- Execute inference; the error occurs during the transformer forward pass
Root Cause Analysis
The error originates in the q8_kernels integration, where attn.to_q() is called with four explicit arguments (five positionals including self):
```python
query = attn.to_q(
    hidden_states, hidden_states_scales, False, torch.bfloat16
)
```
However, FP8Linear.forward() accepts only 2-4 positional arguments (including self), so there is a signature mismatch between the interface the integration expects and the actual implementation.
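For illustration, here is a minimal sketch of the mismatch. The signature below is hypothetical (the real FP8Linear in q8_kernels almost certainly uses different parameter names); it only mirrors the reported 2-4 positional-argument limit:

```python
# Hypothetical stand-in for q8_kernels' FP8Linear; only the arity matters here.
class FP8Linear:
    def forward(self, x, x_scales=None, out_dtype=None):
        # 2-4 positional arguments, counting self
        return x

layer = FP8Linear()
# The integration passes an extra positional (possibly a use_hadamard flag),
# making 5 positionals including self -> the TypeError reported above:
layer.forward("hidden_states", "scales", False, "torch.bfloat16")
```

Running this raises the same message: `TypeError: FP8Linear.forward() takes from 2 to 4 positional arguments but 5 were given`.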
Environment Information
- Model config: configs/ltxv-13b-0.9.8-distilled-fp8.yaml
- Python: 3.11
- PyTorch: 2.7.1+cu126
- q8_kernels: q8_kernels @ file:///workspace/LTX-Video/LTX-Video-Q8-Kernels/dist/q8_kernels-0.0.5-cp311-cp311-linux_x86_64.whl#sha256=54e6ab58fa9acccfd60316e876e0807ced33df91929f59415d9b57a893f0acb7
- Platform: Linux workspace environment
Same error here. Does this mean I can't use the FP8 model?
Same error.
I'm getting the same error.
I've tried debugging the issue a bit. The following call sites in the q8_kernels integration with the diffusers pipeline (file .../q8_kernels/integration/diffusers.py) invoke FP8Linear.forward(...): line 149 (attn.to_q), line 152 (attn.to_k), line 154 (attn.to_v), line 158 (attn.to_q), line 165 (attn.to_k), line 170 (attn.to_v), line 284 (attn.to_out), line 452 (self.proj), and line 469 (self.net[2]). In all of these places a boolean argument (True or False, depending on the invocation) is passed to FP8Linear.forward(...), which does not expect any boolean parameters, so I assume the problem is in either how the callback function is assigned or in the invocation itself. There isn't much information in the repository to go on, so I couldn't trace these changes or work out the semantics of that boolean; perhaps someone changed use_hadamard from a FP8Linear.forward(...) parameter into a class attribute of FP8Linear and forgot to update the invocations.
For now I've added a band-aid in my Python environment by changing all of the invocations above so they no longer pass the boolean argument. Inference works after that, but this is most likely not the proper fix.
PS. It may be better to move this issue to the LTX-Video-Q8-Kernels repository, as the problem is specific to its code.
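For anyone who wants that band-aid without editing site-packages directly, the same idea can be applied as a runtime monkey-patch. This is a sketch only, under the assumptions above (the stray positional is a boolean that FP8Linear.forward() doesn't expect); the import path is a guess, so adjust it to wherever FP8Linear actually lives in your installed q8_kernels:

```python
# Hypothetical import path; locate FP8Linear in your q8_kernels install.
from q8_kernels.modules.linear import FP8Linear

_orig_forward = FP8Linear.forward

def _patched_forward(self, *args, **kwargs):
    # Drop stray boolean positionals (the suspected use_hadamard flag) so the
    # call fits the 2-4 positional arguments forward() actually accepts.
    args = tuple(a for a in args if not isinstance(a, bool))
    return _orig_forward(self, *args, **kwargs)

FP8Linear.forward = _patched_forward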
+1
I've been fighting this today. I naively tried removing the False argument passed in diffusers.py, changing the calls to:
```python
query = attn.to_q(hidden_states, hidden_states_scales, torch.bfloat16)
query = rms_norm_rope(query, freqs_cis[0], freqs_cis[1], attn.q_norm.weight)
key = attn.to_k(hidden_states, hidden_states_scales, torch.bfloat16)
key = rms_norm_rope(key, freqs_cis[0], freqs_cis[1], attn.k_norm.weight)
value = attn.to_v(hidden_states, hidden_states_scales, torch.bfloat16)
```
That bypassed the original issue above but then later things blew up worse (callstack below). I'm just going to have to forgo using the FP8 stuff.
/home/neil/repos/LTX-Video-Q8-Kernels/csrc/gemm/mma_sm89_fp16.hpp:80: static void cute::SM89_16x8x32_F16E4M3E4M3F16_TN::fma(unsigned int &, unsigned int &, const unsigned int &, const unsigned int &, const unsigned int &, const unsigned int &, const unsigned int &, const unsigned int &, const unsigned int &, const unsigned int &): block: [11,0,0], thread: [127,0,0] Assertion `0 && "Attempting to use SM89_16x8x32_F32E4M3E4M3F32_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED"` failed.
0%| | 0/7 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/usr/lib/python3.12/runpy.py", line 198, in _run_module_as_main
return _run_code(code, main_globals, None,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/runpy.py", line 88, in _run_code
exec(code, run_globals)
File "/home/neil/.vscode-server/extensions/ms-python.debugpy-2025.16.0/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 71, in <module>
cli.main()
File "/home/neil/.vscode-server/extensions/ms-python.debugpy-2025.16.0/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 508, in main
run()
File "/home/neil/.vscode-server/extensions/ms-python.debugpy-2025.16.0/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 358, in run_file
runpy.run_path(target, run_name="__main__")
File "/home/neil/.vscode-server/extensions/ms-python.debugpy-2025.16.0/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 310, in run_path
return _run_module_code(code, init_globals, run_name, pkg_name=pkg_name, script_name=fname)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/.vscode-server/extensions/ms-python.debugpy-2025.16.0/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 127, in _run_module_code
_run_code(code, mod_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
File "/home/neil/.vscode-server/extensions/ms-python.debugpy-2025.16.0/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 118, in _run_code
exec(code, run_globals)
File "/home/neil/repos/LTX-Video/inference.py", line 13, in <module>
main()
File "/home/neil/repos/LTX-Video/inference.py", line 9, in main
infer(config=config)
File "/home/neil/repos/LTX-Video/ltx_video/inference.py", line 569, in infer
images = pipeline(
^^^^^^^^^
File "/home/neil/repos/LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py", line 1865, in __call__
result = self.video_pipeline(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/ltx_video/pipelines/pipeline_ltx_video.py", line 1206, in __call__
noise_pred = self.transformer(
^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/ltx_video/models/transformers/transformer3d.py", line 478, in forward
hidden_states = block(
^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/q8_kernels/integration/diffusers.py", line 370, in fused_forward
attn_output = self.attn1(
^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/q8_kernels/integration/diffusers.py", line 67, in attn_forward
return self.processor(
^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/q8_kernels/integration/diffusers.py", line 159, in __call__
query = rms_norm_rope(query, freqs_cis[0], freqs_cis[1], attn.q_norm.weight)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/q8_kernels/functional/ops.py", line 256, in rms_norm_rope
return RMSNormRope.apply(x, weights, cos_freqs, sin_freqs, out_16bit)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/autograd/function.py", line 575, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/q8_kernels/functional/ops.py", line 207, in forward
return torch.ops.q8_kernels_ops.rms_norm_rope(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/_ops.py", line 1116, in __call__
return self._op(*args, **(kwargs or {}))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/_library/autograd.py", line 113, in autograd_impl
result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/_library/autograd.py", line 40, in forward_no_grad
result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/_ops.py", line 721, in redispatch
return self._handle.redispatch_boxed(keyset, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 324, in backend_impl
result = self._backend_fns[device_type](*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/_compile.py", line 32, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 367, in wrapped_fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/neil/repos/LTX-Video/venv/lib/python3.12/site-packages/q8_kernels/functional/ops.py", line 26, in _rms_norm_rope_cuda
return rms_norm_rope_cuda(x, weights, cos_freqs, sin_freqs, out_16bit)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
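The assertion in that trace ("without CUTE_ARCH_MMA_F16_SM89_ENABLED") suggests the FP8 GEMM kernel was dispatched to the SM89 MMA path (Ada Lovelace, compute capability 8.9, e.g. RTX 40xx) but that path wasn't available at runtime, either because the wheel was built without SM89 support or because the GPU isn't Ada-class. A quick way to check what your device reports, using standard PyTorch APIs (the SM89 interpretation of the assertion is my reading, not something confirmed by the repository):

```python
import torch

# Report the GPU's compute capability; the SM89 FP8 MMA kernels named in the
# assertion target compute capability 8.9 (Ada Lovelace).
major, minor = torch.cuda.get_device_capability()
print(f"GPU: {torch.cuda.get_device_name()} (compute capability {major}.{minor})")
if (major, minor) != (8, 9):
    print("Not an SM89 GPU; the SM89-specific FP8 kernel path may be the wrong "
          "dispatch for this device, or the wheel may need rebuilding for this arch.")
```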