
How to use torchjd with complex inputs

Open jixiedy opened this issue 8 months ago • 6 comments

Hello, does torchjd support inputs or intermediate computations involving complex numbers? This is my situation: I have tried many modifications, but I still cannot get torchjd to work. Is there a way to make complex numbers work with torchjd?

jixiedy avatar Jun 14 '25 08:06 jixiedy

Hi! Thanks for the issue. I don't know much about differentiation in the complex domain, so could you provide a minimal reproducible example of something involving complex numbers that fails, along with the result that you would have expected? Maybe use the Mean aggregator to simplify things while we're debugging!

ValerianRey avatar Jun 14 '25 10:06 ValerianRey

Hello, I find it challenging to write a completely reproducible script because the errors often occur only after a large number of iterations (my work is 3D reconstruction research based on 3DGS, which involves Fourier transforms and many complex-valued computations in its intermediate steps). Below is the most common error I encounter when using torchjd:

############
Warning: TorchJD backward failed at iteration 29091, skipping this step
TorchJD backward failed: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging, consider passing CUDA_LAUNCH_BLOCKING=1
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
......
Warning: TorchJD backward failed at iteration 29998, skipping this step
TorchJD backward failed: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging, consider passing CUDA_LAUNCH_BLOCKING=1
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
Falling back to traditional backward pass
Fallback backward also failed: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging, consider passing CUDA_LAUNCH_BLOCKING=1
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
############

Please provide further assistance based on the information above.

jixiedy avatar Jun 15 '25 06:06 jixiedy

I have never seen this error before; I will need more details about the function you're optimizing.

Is it a function from R^n to R^m with complex internal activations, or is it a function from C^n to R^m, or even something else? Is the Jacobian matrix real-valued or complex-valued?

Note: you can register a hook to the aggregator to inspect the Jacobian matrix, see this for something similar.
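As a sketch of the inspection idea, here is a hypothetical wrapper (not torchjd's actual hook API) that assumes an aggregator is a callable mapping the Jacobian matrix to a vector; the `mean_aggregator` stand-in below is purely illustrative:

```python
import torch

# Hypothetical wrapper around any callable aggregator: it prints the
# Jacobian's shape, dtype, and NaN/Inf counts before delegating.
class InspectingAggregator:
    def __init__(self, aggregator):
        self.aggregator = aggregator

    def __call__(self, matrix: torch.Tensor) -> torch.Tensor:
        print(
            f"Jacobian: shape={tuple(matrix.shape)}, dtype={matrix.dtype}, "
            f"nan={torch.isnan(matrix).sum().item()}, "
            f"inf={torch.isinf(matrix).sum().item()}"
        )
        return self.aggregator(matrix)

# Stand-in "mean" aggregator for demonstration (column-wise mean).
mean_aggregator = lambda m: m.mean(dim=0)
inspect = InspectingAggregator(mean_aggregator)
out = inspect(torch.tensor([[1.0, 2.0], [3.0, 5.0]]))  # column-wise mean: tensor([2.0, 3.5])
```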

Also, what aggregator are you using?

Currently, many of our aggregators are only defined on real-valued Jacobians, so I don't really know how they would behave with complex-valued ones. But I'm not sure that your error is related to that at all.

Please set the CUDA_LAUNCH_BLOCKING=1 environment variable when launching your training script, and print the full stack trace of the resulting error.
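For example, a typical invocation looks like this (`train.py` is a placeholder for your own training script):

```shell
# Make CUDA kernel launches synchronous so the stack trace points at the
# actual failing call; tee the output to a log file for inspection.
CUDA_LAUNCH_BLOCKING=1 python train.py 2>&1 | tee torchjd_error.log
```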

ValerianRey avatar Jun 15 '25 12:06 ValerianRey

After several days of experimentation, I have temporarily resolved the issue of errors occurring when using torchjd with multiple intermediate variables (though this may only be a temporary fix). Unfortunately, despite trying nearly all available aggregators, none have performed satisfactorily in my 3D reconstruction task based on 3DGS. This likely isn’t an issue with torchjd itself, but rather a mismatch between torchjd and my specific task.

Regardless, I sincerely appreciate your patience and detailed response.

jixiedy avatar Jun 21 '25 06:06 jixiedy

Could you describe your fix, so that we can document it for other users? Just out of curiosity, was the Jacobian matrix complex? It may be worth making some aggregators accept complex-valued matrices; for instance, any aggregator based on the Gramian of the Jacobian should still work, since we could use $G = J J^\dagger$ instead of $G = J J^\top$.
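To illustrate why the Gramian route could carry over (a sketch, not torchjd code): with the conjugate transpose, $G = J J^\dagger$ is Hermitian with a real, non-negative diagonal and real, non-negative eigenvalues, just like $J J^\top$ in the real case.

```python
import torch

# Gramian of a small complex-valued Jacobian (values are arbitrary).
J = torch.tensor([[1.0 + 1.0j, 2.0 - 1.0j],
                  [0.5 + 0.0j, 1.0 + 2.0j]])
G = J @ J.mH                        # .mH is the conjugate (Hermitian) transpose
eigvals = torch.linalg.eigvalsh(G)  # real eigenvalues of a Hermitian matrix
```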

PierreQuinton avatar Jun 21 '25 08:06 PierreQuinton

Of course, I am more than happy to share my final solution.

Initially, I considered modifying the source code directly to accept complex numbers, but after weighing the time cost against the expected benefit, I decided against it.

Since my project involves a large number of complex-number computations, and the numerical stability of those computations might be the root cause, I wrote two fairly generic helper functions that resolve my problem for now. Their code is as follows:

import torch


def stabilize_for_torchjd(tensor: torch.Tensor, tensor_name: str = "tensor") -> torch.Tensor:
    """
    Numerical stability processing optimized for TorchJD.

    Args:
        tensor: The tensor to be stabilized.
        tensor_name: The name of the tensor (for debugging purposes).

    Returns:
        The stabilized tensor.
    """
    if torch.isnan(tensor).any() or torch.isinf(tensor).any():
        print(f"[TorchJD Stability] Detected NaN/Inf in {tensor_name}, attempting to fix...")
        if torch.is_complex(tensor):
            # nan_to_num expects real replacement values, so fix the real
            # and imaginary parts separately and recombine them.
            tensor = torch.complex(
                torch.nan_to_num(tensor.real, nan=0.0, posinf=1.0, neginf=0.0),
                torch.nan_to_num(tensor.imag, nan=0.0, posinf=1.0, neginf=0.0),
            )
        else:
            tensor = torch.nan_to_num(tensor, nan=0.0, posinf=1.0, neginf=0.0)
    return tensor


def stabilize_sqrt_computation(term_inside_sqrt: torch.Tensor) -> torch.Tensor:
    """
    Stabilizes square root computation.

    Args:
        term_inside_sqrt: The term inside the square root.

    Returns:
        The stabilized square root result.
    """
    # First, fix NaN/Inf
    term_inside_sqrt = stabilize_for_torchjd(term_inside_sqrt, "sqrt_input")
    # Clamp the range and then take the square root
    return torch.sqrt(torch.clamp(term_inside_sqrt, min=0.0, max=1.0))
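As a quick sanity check of the clamp-then-sqrt pattern (a standalone sketch with arbitrary values): entries that would make `sqrt` return NaN (negatives) or that spike far above the expected range are clamped into [0, 1] before the square root is taken.

```python
import torch

# Negative and oversized entries are clamped before the square root,
# so no NaN values appear in the result.
x = torch.tensor([-0.5, 0.0, 0.25, 7.0])
y = torch.sqrt(torch.clamp(x, min=0.0, max=1.0))
```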

jixiedy avatar Jun 21 '25 09:06 jixiedy