
ncc_loss cannot be used with a mask

Open · robert-graf opened this issue 10 months ago · 3 comments

The mask argument of ncc_loss from deepali.losses.functional does not work.

The image loss has already collapsed to shape (batch,) before masked_loss is called, because of the sum(dim=1) calls.

robert-graf · Mar 20 '25

As we discussed earlier, the keepdim flag should avoid this in the most recent version. If it still doesn't work for you, can you send a minimal example to reproduce the issue?

qiuhuaqi · Apr 30 '25

I figured it out. The reshape removes the spatial dimensions: source = source.reshape(source.shape[0], -1).float() flattens the tensor to (batch, -1). From there, the sum(dim=1) reductions collapse the remaining dimension, so the result has shape (batch,) regardless of keepdim=True on the mean. The keepdim behaviour itself is as intended: torch.zeros((10, 10, 10, 10)).mean(1, keepdim=True).shape is (10, 1, 10, 10), but once the spatial dimensions have been flattened away, keepdim cannot bring them back. The loss therefore has shape (batch,) from that point onward.
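
The collapse can be reproduced in isolation (a standalone snippet demonstrating the PyTorch behaviour described above, not taken from deepali itself):

import torch

x = torch.zeros((10, 10, 10, 10))
print(x.mean(1, keepdim=True).shape)         # torch.Size([10, 1, 10, 10]): keepdim preserves the dim on the 4D tensor
flat = x.reshape(x.shape[0], -1)             # torch.Size([10, 1000]): spatial dimensions are flattened away
print(flat.mean(dim=1, keepdim=True).shape)  # torch.Size([10, 1])
print(flat.sum(dim=1).shape)                 # torch.Size([10]): one scalar per batch sample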

Minimal example; the loss already has shape (10,) at the point where the mask is applied:


ncc_loss(torch.zeros((10, 10, 10, 10)), torch.zeros((10, 10, 10, 10)), torch.zeros((10, 10, 10, 10)))

def ncc_loss(
    source: Tensor,
    target: Tensor,
    mask: Optional[Tensor] = None,
    epsilon: float = 1e-15,
    reduction: str = "mean",
) -> Tensor:
    r"""Normalized cross correlation.

    Args:
        source: Source image sampled on ``target`` grid.
        target: Target image with same shape as ``source``.
        mask: Multiplicative mask tensor with same shape as ``source``.
        epsilon: Small constant added to denominator to avoid division by zero.
        reduction: Whether to compute "mean" or "sum" of normalized cross correlation
            of image pairs in batch. If "none", a 1-dimensional tensor is returned
            with length equal the batch size.

    Returns:
        Negative squared normalized cross correlation plus one.

    """

    if not isinstance(source, Tensor):
        raise TypeError("ncc_loss() 'source' must be tensor")
    if not isinstance(target, Tensor):
        raise TypeError("ncc_loss() 'target' must be tensor")
    if source.shape != target.shape:
        raise ValueError("ncc_loss() 'source' must have same shape as 'target'")

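    # NOTE: flattening to (batch, num_voxels) discards the spatial dimensions here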
    source = source.reshape(source.shape[0], -1).float()
    target = target.reshape(source.shape[0], -1).float()
    print(f"{source.shape=},{target.shape=}")
    source_mean = source.mean(dim=1, keepdim=True)
    target_mean = target.mean(dim=1, keepdim=True)
    print(f"{source_mean.shape=},{target_mean.shape=}")
    x = source.sub(source_mean)
    y = target.sub(target_mean)
    print(f"{x.shape=},{y.shape=}")
    a = x.mul(y).sum(dim=1)
    b = x.square().sum(dim=1)
    c = y.square().sum(dim=1)
    print(f"{a.shape=},{b.shape=},{c.shape=}")

    loss = a.square_().div_(b.mul_(c).add_(epsilon)).neg_().add_(1)
    print(f"{loss.shape=},{mask.shape=}")
    loss = masked_loss(loss, mask, "ncc_loss")
    loss = reduce_loss(loss, reduction, mask)
    return loss

Output:
source.shape=torch.Size([10, 1000]),target.shape=torch.Size([10, 1000])
source_mean.shape=torch.Size([10, 1]),target_mean.shape=torch.Size([10, 1])
x.shape=torch.Size([10, 1000]),y.shape=torch.Size([10, 1000])
a.shape=torch.Size([10]),b.shape=torch.Size([10]),c.shape=torch.Size([10])
loss.shape=torch.Size([10]),mask.shape=torch.Size([10, 10, 10, 10])
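
A possible fix (a sketch only, using a hypothetical masked_ncc_loss; this is not deepali's actual implementation) would apply the mask before the reductions, so the NCC statistics are computed over the masked voxels only:

import torch
from torch import Tensor
from typing import Optional

def masked_ncc_loss(source: Tensor, target: Tensor, mask: Optional[Tensor] = None, epsilon: float = 1e-15) -> Tensor:
    # Hypothetical sketch; assumes mask has the same shape as source (as the docstring requires).
    source = source.reshape(source.shape[0], -1).float()
    target = target.reshape(target.shape[0], -1).float()
    if mask is None:
        x = source - source.mean(dim=1, keepdim=True)
        y = target - target.mean(dim=1, keepdim=True)
    else:
        mask = mask.reshape(mask.shape[0], -1).float()
        n = mask.sum(dim=1, keepdim=True).clamp(min=1)
        # Masked means, then zero out unmasked voxels so they do not enter the sums below.
        x = (source - source.mul(mask).sum(dim=1, keepdim=True) / n) * mask
        y = (target - target.mul(mask).sum(dim=1, keepdim=True) / n) * mask
    a = x.mul(y).sum(dim=1)
    b = x.square().sum(dim=1)
    c = y.square().sum(dim=1)
    # One value per batch sample; the "reduction" handling is omitted for brevity.
    return a.square().div(b.mul(c).add(epsilon)).neg().add(1)

With this ordering the mask influences a, b, and c before they collapse to per-sample scalars, instead of being applied after the loss has already lost its spatial shape.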

robert-graf · Apr 30 '25

Thanks for catching this. At first glance, it looks like I may have forgotten about the batch dimension when adding this function.

aschuh-hf · Jul 19 '25