`tensorflow_graphics.util.shape.check_static(...)` throws an error in TF2
I was writing a data augmentation layer for a PointNet implementation and ran into what appears to be a bug in `tensorflow_graphics.util.shape.check_static(...)`, as seen on this line.
Offending layer:
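```python
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow_graphics.geometry.transformation.rotation_matrix_3d import from_euler

class RandomRot(Layer):
    def __init__(self):
        super(RandomRot, self).__init__()

    def build(self, input_shape):
        # One Euler angle per coordinate axis (x, y, z).
        self.s = tf.constant([input_shape[-1],])

    def call(self, inputs, training=None):
        if not training: return inputs
        # Random angles in [0, 6.28), i.e. roughly [0, 2*pi).
        r = tf.random.uniform(
            shape=self.s,
            minval=0,
            maxval=6.28,
        )
        # from_euler(r) builds a 3x3 rotation matrix; its internal
        # check_static call is what raises under graph compilation.
        return tf.linalg.matmul(inputs, from_euler(r))
```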
Error message:
```
AttributeError: in user code:

    <ipython-input-135-d11754641da6>:81 call *
        self.x = self.r(self.x,training)
    <ipython-input-130-07bfe7ac5ab9>:25 call *
        return tf.linalg.matmul(inputs,from_euler(r))
    /usr/local/lib/python3.6/dist-packages/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py:201 from_euler *
        shape.check_static(
    /usr/local/lib/python3.6/dist-packages/tensorflow_graphics/util/shape.py:206 check_static *
        if _get_dim(tensor, axis) != value:
    /usr/local/lib/python3.6/dist-packages/tensorflow_graphics/util/shape.py:135 _get_dim *
        return tensor.shape[axis].value

    AttributeError: 'int' object has no attribute 'value'
```
It appears that `check_static` expects each entry of `.shape` to be a TF1-style `Dimension` object (which has a `.value` attribute), but in TF2 indexing a `TensorShape` returns a plain `int` (or `None`). If I comment out the `check_static` call in `from_euler`, the function works fine. Strangely enough, it seems to work for tensors in eager execution and only throws when using `Dataset` objects with graph compilation.
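To make the mismatch concrete, here is a small pure-Python sketch of a version-agnostic accessor in the spirit of `tf.compat.v1.dimension_value`, which returns the plain size whether it receives a TF1 `Dimension` or a TF2 `int`/`None`. `LegacyDimension` and `get_dim` are hypothetical names invented for this example; `LegacyDimension` only stands in for TF1's `Dimension` so the snippet runs without TensorFlow installed.

```python
from typing import Optional, Sequence, Union


class LegacyDimension:
    """Hypothetical stand-in for TF1's Dimension: wraps a size in `.value`
    (None when the size is unknown)."""

    def __init__(self, value: Optional[int]) -> None:
        self.value = value


def dimension_value(dim: Union[LegacyDimension, int, None]) -> Optional[int]:
    """Return a plain int (or None) for both TF1-style and TF2-style entries."""
    if isinstance(dim, LegacyDimension):
        return dim.value
    return dim  # TF2 behavior: the entry is already an int or None


def get_dim(shape: Sequence, axis: int) -> Optional[int]:
    """Version-agnostic replacement for `shape[axis].value`."""
    return dimension_value(shape[axis])


# TF1-style shape: every entry is a Dimension-like object.
tf1_shape = [LegacyDimension(None), LegacyDimension(3)]
# TF2-style shape: every entry is already an int (or None).
tf2_shape = [None, 3]

print(get_dim(tf1_shape, -1), get_dim(tf2_shape, -1))
```

In TensorFlow itself, writing `tf.compat.v1.dimension_value(tensor.shape[axis])` instead of `tensor.shape[axis].value` should give the same result under both TF1 and TF2 shape semantics.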
Any update on this? I get the same error. Here is the simplest code that triggers it:
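```python
import tensorflow as tf
import tensorflow_graphics as tfg
quat = tf.constant([[0., 0., 0., 1.]], dtype=tf.float64)
# Eager execution.
euler = tfg.geometry.transformation.euler.from_quaternion(quat)
print(euler)
# Graph mode via tf.function: this is where the error appears.
@tf.function
def rot(quat):
    euler = tfg.geometry.transformation.euler.from_quaternion(quat)
    return euler
print(rot(quat))
```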
This triggers the same error that @m4ttr4ymond reported.
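Until `check_static` is fixed, one possible workaround for the `RandomRot` layer from the original report is to build the rotation matrix with plain TensorFlow ops instead of calling `from_euler`, so no TF1-style shape check runs during tracing. This is a hand-rolled sketch, not tensorflow_graphics' implementation. The helpers `_mat3` and `euler_to_matrix` are hypothetical, and the composition order R = Rz · Ry · Rx is an assumption. The angles are drawn from [0, 2π) rather than [0, 6.28).

```python
import math

import numpy as np
import tensorflow as tf


def _mat3(*entries):
    # Pack nine scalar tensors (row-major) into a 3x3 matrix.
    return tf.reshape(tf.stack(entries), (3, 3))


def euler_to_matrix(angles):
    """Hypothetical helper: angles = [ax, ay, az] in radians.

    Assumes R = Rz @ Ry @ Rx (rotate about x first, then y, then z).
    """
    ax, ay, az = tf.unstack(angles)
    cx, sx = tf.cos(ax), tf.sin(ax)
    cy, sy = tf.cos(ay), tf.sin(ay)
    cz, sz = tf.cos(az), tf.sin(az)
    one, zero = tf.ones_like(ax), tf.zeros_like(ax)
    rx = _mat3(one, zero, zero, zero, cx, -sx, zero, sx, cx)
    ry = _mat3(cy, zero, sy, zero, one, zero, -sy, zero, cy)
    rz = _mat3(cz, -sz, zero, sz, cz, zero, zero, zero, one)
    return rz @ ry @ rx


class RandomRot(tf.keras.layers.Layer):
    """Randomly rotates (batch, num_points, 3) point clouds while training."""

    def call(self, inputs, training=None):
        if not training:
            return inputs
        # Three random Euler angles in [0, 2*pi), in the input dtype.
        angles = tf.random.uniform(
            (3,), minval=0.0, maxval=2.0 * math.pi, dtype=inputs.dtype
        )
        # Row vectors times R; matmul broadcasts R over the batch dimension.
        return tf.linalg.matmul(inputs, euler_to_matrix(angles))


layer = RandomRot()


@tf.function  # graph mode, where from_euler's check_static failed
def augment(points):
    return layer(points, training=True)


points = tf.random.normal((2, 5, 3))
rotated = augment(points)
# A rotation preserves each point's distance from the origin.
print(np.allclose(tf.norm(rotated, axis=-1).numpy(),
                  tf.norm(points, axis=-1).numpy(), atol=1e-4))
```

If you only want rotations about a single axis, you can fix the other two angles at zero instead of sampling them.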