lingvo
Feature request: lingvo.jax.asserts.HasShape
I tried
    def AssertShape(x: jnp.array, shape) -> None:
        if not jnp.array_equal(x.shape, shape):
            raise ValueError(f'Shape mismatch: found {x.shape}, expected: {shape}')
and got
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
(BTW, note the double period)
To answer my own question:
    def AssertShape(x: jnp.ndarray, expected_shape) -> None:
        # x.shape is a static Python tuple, even under jit, so compare it with
        # plain Python operators rather than jnp ops (which return traced arrays).
        xs = x.shape
        if len(xs) != len(expected_shape):
            raise ValueError(
                f'Wrong rank: got [{len(xs)}] {x.shape}, expected [{len(expected_shape)}] {expected_shape}')
        for k, d in enumerate(expected_shape):
            if xs[k] != d:
                raise ValueError(f'Wrong shape at dim {k}: got {x.shape}, expected {expected_shape}')
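
For anyone landing here later, this is a minimal sketch of how a check like this behaves under jax.jit, assuming static shapes. The names assert_shape and scaled_sum are made up for illustration and are not part of lingvo. The check runs in plain Python at trace time, so it never touches a traced value and raises an ordinary ValueError when a mismatched input forces a retrace.

```python
import jax
import jax.numpy as jnp


def assert_shape(x, expected_shape) -> None:
    # Hypothetical compact variant of the check above. Assumes static shapes:
    # x.shape is then a tuple of Python ints at trace time, so tuple equality
    # yields an ordinary Python bool, not a traced array.
    if tuple(x.shape) != tuple(expected_shape):
        raise ValueError(
            f'Shape mismatch: found {x.shape}, expected: {expected_shape}')


@jax.jit
def scaled_sum(x):
    # Hypothetical jitted function; the check runs while tracing,
    # not on every call that hits the compilation cache.
    assert_shape(x, (2, 3))
    return jnp.sum(x) * 2.0


ok = scaled_sum(jnp.ones((2, 3)))

try:
    scaled_sum(jnp.ones((3, 2)))  # A new input shape triggers a retrace.
    mismatch_caught = False
except ValueError:
    mismatch_caught = True
```

The tuple comparison drops the per-dimension error message of the longer version, so which one to prefer probably depends on how much detail is wanted in the error.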