
`tensorflow_graphics.shape.check_static(...)` throwing error in TF2

Open • m4ttr4ymond opened this issue 4 years ago • 1 comment

I was writing a data augmentation layer for a PointNet implementation and ran into what appears to be a bug in `tensorflow_graphics.util.shape.check_static(...)`, as seen on this line.

Offending layer:

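```python
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow_graphics.geometry.transformation.rotation_matrix_3d import from_euler


class RandomRot(Layer):
  def __init__(self):
    super(RandomRot, self).__init__()

  def build(self, input_shape):
    # One Euler angle per coordinate axis (input_shape[-1] == 3 for xyz points).
    self.s = tf.constant([input_shape[-1],])

  def call(self, inputs, training=None):
    if not training:
      return inputs

    # Random Euler angles in [0, 6.28), i.e. roughly [0, 2*pi).
    r = tf.random.uniform(
      shape=self.s,
      minval=0,
      maxval=6.28,
    )

    # from_euler turns the [3] angle vector into a [3, 3] rotation matrix.
    return tf.linalg.matmul(inputs, from_euler(r))
```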

Error message:

```text
AttributeError: in user code:

    <ipython-input-135-d11754641da6>:81 call  *
        self.x = self.r(self.x,training)
    <ipython-input-130-07bfe7ac5ab9>:25 call  *
        return tf.linalg.matmul(inputs,from_euler(r))
    /usr/local/lib/python3.6/dist-packages/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py:201 from_euler  *
        shape.check_static(
    /usr/local/lib/python3.6/dist-packages/tensorflow_graphics/util/shape.py:206 check_static  *
        if _get_dim(tensor, axis) != value:
    /usr/local/lib/python3.6/dist-packages/tensorflow_graphics/util/shape.py:135 _get_dim  *
        return tensor.shape[axis].value

    AttributeError: 'int' object has no attribute 'value'
```

It appears that `check_static` expects each element of `.shape` to be a TF1 `Dimension` object, which exposes its size through `.value`. In TF2, indexing a `TensorShape` returns a plain `int` (or `None` for an unknown dimension), so `.value` fails. If I comment out the `check_static` call in `from_euler`, the function works fine. Strangely, it seems to work for tensors in eager execution and only throws the error when using `Dataset` objects with graph compilation.
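The TF1/TF2 difference can be illustrated without TensorFlow. A version-agnostic accessor only needs to unwrap `.value` when it is present and otherwise return the dimension as-is; TensorFlow provides `tf.compat.v1.dimension_value` for this purpose. The sketch below is plain Python under that assumption: `FakeDimension` and `get_dim` are hypothetical names standing in for TF1's `Dimension` and for a patched `_get_dim`, not the library's actual code.

```python
class FakeDimension:
    """Hypothetical stand-in for TF1's Dimension: wraps a size in `.value`."""

    def __init__(self, value):
        self.value = value


def get_dim(shape, axis):
    """Return the size of shape[axis] as an int, or None if unknown.

    Works whether indexing the shape yields Dimension-like objects (TF1 style)
    or plain ints / None (TF2 style).
    """
    dim = shape[axis]
    # TF1-style dimensions expose the size via `.value`; ints and None do not.
    return getattr(dim, "value", dim)


tf1_style_shape = [FakeDimension(2), FakeDimension(3)]
tf2_style_shape = [2, 3]
unknown_batch_shape = [None, 3]

print(get_dim(tf1_style_shape, -1))     # 3
print(get_dim(tf2_style_shape, -1))     # 3
print(get_dim(unknown_batch_shape, 0))  # None
```

In TensorFlow code the equivalent call would be `tf.compat.v1.dimension_value(tensor.shape[axis])`, which accepts either form.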

m4ttr4ymond, Jul 25 '21 20:07

Any update on this? I get the same error. Here is the simplest code I found that triggers it:

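```python
import tensorflow as tf
import tensorflow_graphics as tfg

quat = tf.constant([[0., 0., 0., 1.]], dtype=tf.float64)

# Eager call.
euler = tfg.geometry.transformation.euler.from_quaternion(quat)
print(euler)

# Same call inside a tf.function (graph mode).
@tf.function
def rot(quat):
    euler = tfg.geometry.transformation.euler.from_quaternion(quat)
    return euler

print(rot(quat))
```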

This triggers the same error that @m4ttr4ymond reported.

NicolayP, Mar 04 '22 17:03