
MaisiVAE: Auto-cast GroupNorm, deprecate norm_float16

Open johnzielke opened this issue 1 year ago • 10 comments

Description

The current MAISI VAE encoder only supports float32 and float16 for its group norm, selected through a parameter (norm_float16) that must be set manually. This PR instead infers the norm datatype from the input and should therefore enable the use of other datatypes.
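A minimal sketch of the behavior described above, assuming a plain functional group norm: the normalized output takes the dtype of its input instead of a fixed float16/float32. The helper name `group_norm_like_input` is hypothetical and is not MONAI's API. Computing the statistics in float32 before casting back is an extra assumption made here for numerical safety; it is not taken from the PR.

```python
import torch
import torch.nn.functional as F


def group_norm_like_input(x: torch.Tensor, num_groups: int, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical group norm whose output dtype follows the input dtype.

    Statistics are computed in float32 (an assumption for numerical safety),
    then the result is cast back to whatever dtype the caller passed in.
    """
    y = F.group_norm(x.float(), num_groups, eps=eps)
    return y.to(x.dtype)


x = torch.randn(1, 8, 4, 4, 4, dtype=torch.bfloat16)
print(group_norm_like_input(x, num_groups=4).dtype)  # torch.bfloat16
```

Because no dtype is hard-coded, the same layer would work unchanged whether the surrounding model runs in float32, float16, or bfloat16.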

Types of changes

  • [x] Non-breaking change (fix or new feature that would not break existing functionality).
  • [ ] Breaking change (fix or new feature that would cause existing functionality to change).
  • [x] New tests added to cover the changes.
  • [ ] Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • [ ] Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • [ ] In-line docstrings updated.
  • [ ] Documentation updated, tested make html command in the docs/ folder.

johnzielke, Feb 04 '25 14:02

Hi @johnzielke, could you please resolve the conflict? Then I can help trigger the blossom build. Thanks!

KumoLiu, Apr 13 '25 08:04

/build

KumoLiu, Apr 21 '25 04:04

Hi @dongyang0122, could you please check whether this change makes sense to you? I just want to confirm there is no other specific concern regarding the norm_float16 parameter.

KumoLiu, Apr 21 '25 04:04

/build

KumoLiu, Apr 22 '25 06:04

Hi, are there any updates on this PR?

johnzielke, Apr 30 '25 14:04

> Hi, are there any updates on this PR?

Hello @johnzielke, sorry for the late response.

I discussed this offline with @dongyang0122. The norm_float16 option was added to prevent out-of-memory (OOM) issues during inference, which have been a bottleneck for MAISI. The reason for not using float16 more broadly was to avoid reducing the precision of the other layers and thus prevent truncation errors. Therefore, I suggest that we keep this argument. What do you think?
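To make the trade-off concrete, here is a minimal sketch of what a flag like norm_float16 buys and costs, assuming the flag simply stores the normalized output in float16. The `normalize` and `nbytes` helpers are hypothetical; they do not reproduce MAISI's implementation or its peak-memory behavior inside the norm, only the size and precision of the retained output.

```python
import torch
import torch.nn.functional as F


def normalize(x: torch.Tensor, num_groups: int, norm_float16: bool) -> torch.Tensor:
    # Hypothetical stand-in for a norm_float16-style flag: normalize, then
    # optionally store the result in float16 to halve its memory footprint.
    y = F.group_norm(x, num_groups)
    return y.to(torch.float16) if norm_float16 else y


def nbytes(t: torch.Tensor) -> int:
    # Storage size of the tensor's elements in bytes
    return t.element_size() * t.nelement()


x = torch.randn(1, 32, 16, 16, 16)
y32 = normalize(x, 8, norm_float16=False)
y16 = normalize(x, 8, norm_float16=True)
print(nbytes(y32) // nbytes(y16))  # 2

# Precision cost: float16 keeps about 11 significant bits, so values are rounded
print((y16.float() - y32).abs().max().item())
```

The ratio shows the memory saving, and the rounding error shows the truncation that keeping the other layers in higher precision is meant to avoid.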

KumoLiu, Apr 30 '25 15:04

Thank you for the info. It surprises me that this deceptively simple operation would be a memory problem during inference compared to attention or other operations. But leaving that aside, this raises another issue. I think the reason I created this PR is that, if I recall correctly, when the group norm returns float16 while the rest of the network uses a different dtype, the subsequent convolutions fail because their weights have a different dtype (float32 by default). I ran into this after converting the model to another dtype such as bfloat16 using vae_model.to(torch.bfloat16).
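The failure mode described above can be reproduced with a minimal sketch, assuming a norm whose output is hard-coded to float16 and feeds a standard `nn.Conv3d`. The helper `runs_without_dtype_error` is hypothetical; it only checks whether PyTorch accepts the input/weight dtype combination.

```python
import torch
import torch.nn as nn


def runs_without_dtype_error(conv: nn.Conv3d, x: torch.Tensor) -> bool:
    """Return True if conv(x) succeeds, False if it raises a RuntimeError (here, a dtype mismatch)."""
    try:
        conv(x)
        return True
    except RuntimeError:
        return False


conv = nn.Conv3d(4, 4, kernel_size=3, padding=1)  # float32 weights by default
x = torch.randn(1, 4, 8, 8, 8)

# A norm whose output is hard-coded to float16 feeds a float32 convolution
normed = nn.functional.group_norm(x, num_groups=2).to(torch.float16)
print(runs_without_dtype_error(conv, normed))  # False

# Casting the norm output to the convolution's parameter dtype avoids the mismatch
print(runs_without_dtype_error(conv, normed.to(conv.weight.dtype)))  # True

# The same mismatch appears after converting the model to bfloat16: float16 != bfloat16
conv_bf16 = nn.Conv3d(4, 4, kernel_size=3, padding=1).to(torch.bfloat16)
print(runs_without_dtype_error(conv_bf16, normed))  # False
```

PyTorch convolutions require the input and weight dtypes to match, so any fixed output dtype in the norm breaks as soon as the model is converted to a different one.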

Without going into why the group norm is so memory-intensive, my other suggestion is to let the norm_float16 parameter also accept a dtype, so that it can use float16, bfloat16, or any other dtype.

The current implementation makes it hard or impossible to use this model with the new high-performance datatypes introduced in newer NVIDIA GPUs, which is very unfortunate in my opinion. That is how I stumbled on this problem, since the documentation does not mention that this use case is unsupported.

johnzielke, Apr 30 '25 16:04

> my other suggestion is to make the norm_float16 parameter accept a dtype as well, so that you can have it use float16, bfloat16 or any of the other dtypes.

Yes, sure. That makes total sense and makes it easier to support more use cases. Could you please make this change in this PR while keeping it backward compatible with the current use case? Thanks in advance!

KumoLiu, Apr 30 '25 16:04

Yes. I won't rename the parameter then. The name will be somewhat confusing, but this keeps the amount of change and deprecation to a minimum. I suggest:

norm_float16: bool | str | None

  • bool: current behavior
  • str: name of the torch dtype to convert to
  • None: the behavior initially proposed in this PR, where the norm output takes whatever dtype the input has
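The proposed three-way semantics could be resolved to a concrete dtype roughly as follows. This is a sketch, not the PR's code: `resolve_norm_dtype` is a hypothetical helper, and mapping `False` to float32 is an assumption about the existing default that the thread does not confirm.

```python
from typing import Union

import torch


def resolve_norm_dtype(norm_float16: Union[bool, str, None], input_dtype: torch.dtype) -> torch.dtype:
    """Hypothetical mapping of the proposed norm_float16 values to an output dtype.

    - True  -> torch.float16 (existing behavior)
    - False -> torch.float32 (assumed existing default; not confirmed in the thread)
    - str   -> the torch dtype with that name, e.g. "bfloat16"
    - None  -> the dtype of the input tensor
    """
    if norm_float16 is None:
        return input_dtype
    if isinstance(norm_float16, bool):
        return torch.float16 if norm_float16 else torch.float32
    if isinstance(norm_float16, str):
        # Look the name up on the torch module and make sure it really is a dtype
        dtype = getattr(torch, norm_float16, None)
        if not isinstance(dtype, torch.dtype):
            raise ValueError(f"unknown torch dtype name: {norm_float16!r}")
        return dtype
    raise TypeError(f"norm_float16 must be bool, str or None, got {type(norm_float16).__name__}")


print(resolve_norm_dtype("bfloat16", torch.float32))  # torch.bfloat16
print(resolve_norm_dtype(None, torch.bfloat16))  # torch.bfloat16
```

Accepting a string rather than a `torch.dtype` object mirrors the `str` option in the proposal; validating it with `isinstance(..., torch.dtype)` rejects names like `"tensor"` that exist on the module but are not dtypes.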

johnzielke, Apr 30 '25 17:04

For my own education: why does the current implementation need so much less memory than a simple vectorized version like this?

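  import torch

  def group_norm_3d(input: torch.Tensor, num_groups: int, eps: float = 1e-5) -> torch.Tensor:
      param_n, param_c, param_d, param_h, param_w = input.shape
      # (N, C, D, H, W) -> (N, G, C // G, D, H, W)
      input_g = input.view(param_n, num_groups, param_c // num_groups, param_d, param_h, param_w)

      mean = input_g.mean([2, 3, 4, 5], keepdim=True)
      std = input_g.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(eps).sqrt_()
      # Normalize in the grouped layout, then restore (N, C, D, H, W)
      return ((input_g - mean) / std).view_as(input)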

The mean and std tensors should be quite small, and the result is the same size as the input, so why is the loop implementation more memory-efficient?
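The thread does not settle this question. As a point of comparison only, here is a hypothetical per-group variant (not MAISI's actual code) that preallocates the output and writes each group into it, so its temporaries are bounded by the size of one group. In the vectorized version, `(input_g - mean) / std` needs a full-size intermediate in addition to the full-size result. Whether this accounts for MAISI's savings is an assumption.

```python
import torch


def group_norm_3d_looped(x: torch.Tensor, num_groups: int, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical group norm over (N, C, D, H, W) that processes one group at a time.

    The output is allocated once; each iteration only creates temporaries
    the size of a single group of channels.
    """
    n, c, d, h, w = x.shape
    if c % num_groups != 0:
        raise ValueError("number of channels must be divisible by num_groups")
    group_size = c // num_groups
    out = torch.empty_like(x)
    for g in range(num_groups):
        sl = slice(g * group_size, (g + 1) * group_size)
        chunk = x[:, sl]  # a view, no copy: (N, C // G, D, H, W)
        mean = chunk.mean(dim=(1, 2, 3, 4), keepdim=True)
        var = chunk.var(dim=(1, 2, 3, 4), unbiased=False, keepdim=True)
        # Write the normalized group straight into the preallocated output
        out[:, sl] = (chunk - mean) / torch.sqrt(var + eps)
    return out
```

Numerically it should match `torch.nn.functional.group_norm` without affine parameters; only the allocation pattern differs.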

johnzielke, Apr 30 '25 21:04