Feature request: lingvo.jax.asserts.HasShape

Status: Open. drpngx opened this issue 2 years ago (1 comment).

I tried

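```python
import jax.numpy as jnp

def AssertShape(x: jnp.ndarray, shape) -> None:
  # Under jax.jit, jnp.array_equal returns a traced array, so `if not ...` fails.
  if not jnp.array_equal(x.shape, shape):
    raise ValueError(f'Shape mismatch: found {x.shape}, expected: {shape}')
```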

and got

```
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
```

(BTW, note the double period at the end of the error message.)
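A minimal sketch of why the first attempt fails. It assumes the check ran inside `jax.jit`, which the report does not state. `x.shape` is a plain Python tuple of ints. `jnp.array_equal`, however, returns a JAX array, and while a function is being traced that array is a tracer. Python's `if` needs a concrete `bool`, so the conversion fails. Comparing the shapes as Python tuples keeps everything static. The names `check_with_jnp` and `check_with_python` are hypothetical, used only for illustration.

```python
import jax
import jax.numpy as jnp


def check_with_jnp(x):
  # jnp.array_equal returns a JAX array; under jit it is a tracer,
  # so `if` needs a concrete bool and cannot get one.
  if not jnp.array_equal(x.shape, (2, 3)):
    raise ValueError('shape mismatch')
  return x


def check_with_python(x):
  # x.shape is a static tuple at trace time; tuple != tuple is a plain bool.
  if tuple(x.shape) != (2, 3):
    raise ValueError('shape mismatch')
  return x


x = jnp.zeros((2, 3))

try:
  jax.jit(check_with_jnp)(x)
except jax.errors.TracerBoolConversionError as e:
  print('jnp version failed:', type(e).__name__)

print(jax.jit(check_with_python)(x).shape)
```

The Python-tuple version does all of its work at trace time. A shape mismatch surfaces as an ordinary `ValueError` when `jax.jit` traces the function.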

drpngx, Sep 08 '23 15:09

To answer my own question:

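```python
import jax.numpy as jnp

def AssertShape(x: jnp.ndarray, expected_shape) -> None:
  # expected_shape must be a static Python sequence of ints (not a JAX array).
  # x.shape is a static tuple even under jax.jit, so every comparison below
  # is plain Python and no traced array is ever converted to bool.
  xs = x.shape
  if len(xs) != len(expected_shape):
    raise ValueError(
        f'Wrong rank: got [{len(xs)}] {xs}, expected [{len(expected_shape)}] {expected_shape}')
  # Same rank: report the first dimension that differs.
  for k, d in enumerate(expected_shape):
    if xs[k] != d:
      raise ValueError(f'Wrong shape at dim {k}: got {xs}, expected {expected_shape}')
```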
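A short usage sketch. It repeats `AssertShape` from the comment above so the snippet runs on its own. The check executes at trace time, so a mismatch raises `ValueError` when `jax.jit` traces the function, before anything is compiled or run. A passing check adds nothing to the compiled computation. The function `forward` is a hypothetical layer used only for illustration.

```python
import jax
import jax.numpy as jnp


def AssertShape(x: jnp.ndarray, expected_shape) -> None:
  # Same static, trace-time shape check as in the comment above.
  xs = x.shape
  if len(xs) != len(expected_shape):
    raise ValueError(
        f'Wrong rank: got [{len(xs)}] {xs}, expected [{len(expected_shape)}] {expected_shape}')
  for k, d in enumerate(expected_shape):
    if xs[k] != d:
      raise ValueError(f'Wrong shape at dim {k}: got {xs}, expected {expected_shape}')


@jax.jit
def forward(w, x):
  # Hypothetical layer: validate input shapes before the matmul.
  AssertShape(w, (3, 4))
  AssertShape(x, (2, 3))
  return x @ w


print(forward(jnp.ones((3, 4)), jnp.ones((2, 3))).shape)  # (2, 4)

try:
  # A new input shape triggers a retrace, and the check fails during tracing.
  forward(jnp.ones((3, 4)), jnp.ones((2, 5)))
except ValueError as e:
  print(e)
```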

drpngx, Sep 10 '23 16:09