
[Relax] Fix a bug that occurred due to shape inference not handling static dim vs symbolic dim

Open · vacu9708 opened this pull request · 3 comments

Summary

This PR fixes https://github.com/apache/tvm/issues/17964, a bug triggered by the following sequence:


  1. The Compress node produces a symbolic dimension `num_nonzero`.
  2. The `relax.op.add()` in the BiasGelu node adds tensors of shape `[2, num_nonzero]` and shape `[3]`.
  3. The current binary shape inference does not handle the static-vs-symbolic dimension case and produces a `None` shape, which leads to the error.
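For intuition, the steps above can be reproduced as a NumPy analogue (this is a sketch of the shape pattern only, not the actual Relax IR or lowering): `compress` yields a data-dependent second dimension, and whether the subsequent bias add is valid depends on how many columns survive.

```python
import numpy as np

# NumPy analogue of the failing pattern: compress produces a data-dependent
# second dim (the "num_nonzero" symbolic dim), then a bias of shape (3,) is
# added, as in the BiasGelu node.
data = np.ones((2, 3))
bias = np.zeros(3)

# All columns kept: second dim is 3, so the add broadcasts cleanly.
kept_all = np.compress([True, True, True], data, axis=1)  # shape (2, 3)
out = kept_all + bias                                     # valid, shape (2, 3)

# Only two columns kept: second dim is 2, which is incompatible with (3,).
kept_two = np.compress([True, False, True], data, axis=1)  # shape (2, 2)
try:
    kept_two + bias  # 2 vs 3: broadcasting fails at runtime
    broadcast_failed = False
except ValueError:
    broadcast_failed = True
```

This is exactly why the output dim cannot be fixed without knowing the runtime value of the symbolic dimension.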

Changes

  • Add comments on each case
  • Add support for comparing static and symbolic dimensions in shape inference

Notes

The current binary shape inference takes into account the following cases:

| Case | Example | Expected output dim |
| --- | --- | --- |
| static dim(1) | (2, 3) + (1, 3) | (2, 3) |
| equal static dims | (2, 3) + (2, 3) | (2, 3) |
| equal symbolic dims | (n, m) + (n, m) | (n, m) |

However, it does not take into account the following cases:

| Case | Example | Expected output dim |
| --- | --- | --- |
| static dim vs symbolic dim | (2, 3) + (2, n) | (2, 3), because the 2nd dim must be the static dim (3) regardless of the symbolic dim (n) |
| different symbolic dims | (2, n) + (2, m) | (2, n) or (2, m); the output dim cannot be determined at compile time |

The static dim vs symbolic dim case can be determined at compile time.
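The two tables above condense into a per-dimension rule. The sketch below is a hypothetical Python rendering of that rule, not the actual Relax implementation; static dims are modeled as ints and symbolic dims as stand-in strings such as `"n"`.

```python
def infer_binary_dim(lhs, rhs):
    """Hypothetical per-dimension broadcast inference (illustrative only).

    Returns the inferred output dim, or None when it cannot be
    determined at compile time.
    """
    if lhs == 1:        # static dim(1): broadcast to the other side
        return rhs
    if rhs == 1:
        return lhs
    if lhs == rhs:      # equal static dims, or equal symbolic dims
        return lhs
    lhs_static = isinstance(lhs, int)
    rhs_static = isinstance(rhs, int)
    if lhs_static and not rhs_static:
        # static vs symbolic: a valid op forces the symbolic dim to be
        # 1 or equal to the static dim, so the output is the static dim
        return lhs
    if rhs_static and not lhs_static:
        return rhs
    return None         # different symbolic dims: undecidable
```

Under this rule, `(2, 3) + (2, n)` infers the second dim as 3 at compile time, while `(2, n) + (2, m)` still yields an unknown dim.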

vacu9708 · Jul 06 '25 14:07

I disagree with the premise that (2, 3) + (2, n) can be statically inferred to have the shape (2, 3). The validity of this operation depends on the runtime value of n.

Specifically:

  • If n=1 or n=3, the operation is valid due to broadcasting, and the output shape is indeed (2, 3).
  • If n is any other value, the shapes are incompatible, and the operation should raise a runtime error.

To handle this correctly, we need a mechanism for runtime shape validation. I suggest the following steps:

  1. Assume the valid shape: For this specific operation, we can tentatively define the output shape as (2, 3) (as this pull request does), since it's the only possible shape for a valid operation.
  2. Insert a runtime check: We MUST add an assertion to the model to verify at runtime that n is either 1 or 3.
  3. (Optional) Hoist the checks: We could write an optimization pass to lift all such runtime assertions to the beginning of the model and potentially combine them for efficiency.
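Steps 1 and 2 above can be sketched for the `(2, 3) + (2, n)` example as follows. This is plain Python, not Relax IR, and the helper name is made up for illustration; in the compiled model the check would be an inserted assertion node.

```python
def check_broadcast_dim(symbolic_value, static_dim):
    """Hypothetical runtime guard for a symbolic-vs-static dim (sketch only)."""
    # Step 2: verify at runtime that the symbolic dim is broadcast-compatible
    # with the static dim, i.e. it is either 1 or equal to it.
    if symbolic_value not in (1, static_dim):
        raise ValueError(
            f"Broadcast mismatch: symbolic dim {symbolic_value} "
            f"is neither 1 nor {static_dim}"
        )
    # Step 1: the tentatively assumed output dim is the static dim either way.
    return static_dim
```

For `static_dim = 3`, the guard passes only when `n` is 1 or 3 and raises for any other runtime value, which is precisely the validity condition described above.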

cc @tqchen

Hzfengsy · Jul 07 '25 05:07

@Hzfengsy Thanks for your feedback.

I'll review the mechanism you proposed, try to implement it, and get back to you.

vacu9708 · Jul 07 '25 06:07

I inspected the operator implementations in the source code and couldn't find any precedent for validating dynamic tensor shapes and raising a runtime error. Given my currently limited experience with TVM, I'm afraid I won't be able to implement your suggestion until such a precedent exists, but I'll try to implement it in the future.

vacu9708 · Jul 13 '25 10:07