[Relax] Fix a bug that occurred due to shape inference not handling static dim vs symbolic dim
Summary
This PR fixes https://github.com/apache/tvm/issues/17964, which arises through the following process:
- The Compress node produces a symbolic dimension `num_nonzero`.
- The `relax.op.add()` in the BiasGelu node adds `shape=[2, num_nonzero]` and `shape=[3]`.
- The current binary shape inference does not take into account static vs symbolic dimension cases and produces a `None` shape, which leads to the error.
Changes
- Add comments on each case
- Add support for comparing static and symbolic dimensions in shape inference
Notes
The current binary shape inference takes into account the following cases:
| Case | Example | expected output dim |
|---|---|---|
| static dim (1) | (2, 3) + (1, 3) | (2, 3) |
| equal static dims | (2, 3) + (2, 3) | (2, 3) |
| equal symbolic dims | (n, m) + (n, m) | (n, m) |
However, it does not take into account the following cases:
| Case | Example | expected output dim |
|---|---|---|
| static dim vs symbolic dim | (2, 3) + (2, n) | (2, 3), because the 2nd dim must be the static dim (3) regardless of the symbolic dim (n) |
| different symbolic dims | (2, n) + (2, m) | (2, n) or (2, m), because the output dim cannot be determined at compile time |
The static dim vs symbolic dim case can be determined at compile time.
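The per-dimension logic described by the tables above can be sketched in plain Python. This is an illustrative model, not TVM's actual implementation: static dims are ints, symbolic dims are strings (e.g. `"n"`), and `None` stands for "undecidable at compile time".

```python
def merge_dim(d1, d2):
    """Merge one output dimension of a broadcasting binary op (sketch)."""
    # Broadcasting: a static 1 yields the other dimension.
    if d1 == 1:
        return d2
    if d2 == 1:
        return d1
    # Equal dims (both static, or the same symbol) pass through.
    if d1 == d2:
        return d1
    # Static vs symbolic: the output must equal the static dim; whether
    # the symbolic dim is actually compatible (1 or equal to the static
    # dim) can only be verified at runtime.
    if isinstance(d1, int) and isinstance(d2, str):
        return d1
    if isinstance(d2, int) and isinstance(d1, str):
        return d2
    # Two different symbols: undecidable at compile time.
    return None

print(merge_dim(3, "n"))   # static vs symbolic -> 3
print(merge_dim("n", "m")) # different symbols  -> None
```

Before this fix, the static-vs-symbolic branch was missing, so `(2, 3) + (2, n)` fell through to the `None` case.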
I disagree with the premise that (2, 3) + (2, n) can be statically inferred to have the shape (2, 3). The validity of this operation depends on the runtime value of n.
Specifically:
- If `n=1` or `n=3`, the operation is valid due to broadcasting, and the output shape is indeed `(2, 3)`.
- If `n` is any other value, the shapes are incompatible, and the operation should raise a runtime error.
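A minimal sketch of the runtime check this implies (hypothetical helper, not a TVM API): once the symbolic dim `n` takes a concrete value, it must be 1 (broadcast) or match the static dim 3, otherwise the add must fail.

```python
def validate_symbolic_dim(n, static_dim=3):
    """Check a concrete value of n against the static dim it is added to."""
    # Valid only if n broadcasts (n == 1) or matches the static dim.
    if n not in (1, static_dim):
        raise ValueError(f"cannot broadcast dim {n} against dim {static_dim}")
    # Either way, the output dim is the static one.
    return static_dim

print(validate_symbolic_dim(1))  # -> 3 (broadcast)
print(validate_symbolic_dim(3))  # -> 3 (exact match)
```

Any other value, e.g. `validate_symbolic_dim(4)`, raises `ValueError`.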
To handle this correctly, we need a mechanism for runtime shape validation. I suggest the following steps:
- Assume the valid shape: For this specific operation, we can tentatively define the output shape as `(2, 3)` (as this pull request does), since it's the only possible shape for a valid operation.
- Insert a runtime check: We MUST add an assertion to the model to verify at runtime that `n` is either 1 or 3.
- (Optional) Hoist the checks: We could write an optimization pass to lift all such runtime assertions to the beginning of the model and potentially combine them for efficiency.
cc @tqchen
@Hzfengsy Thanks for your feedback.
I'll review the mechanism you proposed, try to implement it, and get back to you.
I inspected the operator implementations in the source code and couldn't find any precedent for asserting or validating dynamic tensors and raising a runtime error. Given my currently limited experience with TVM, I'm afraid I won't be able to implement your suggestion until such a precedent exists. I'll try to implement it in the future.