
TFJS reduction ops do not support zero-size tensors as TF does

Open pyu10055 opened this issue 3 years ago • 0 comments


System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow.js):
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): All
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: Yes
  • TensorFlow.js installed from (npm or script link): Both
  • TensorFlow.js version (use command below): 3.18
  • Browser version: Latest
  • Tensorflow.js Converter Version: 3.18

Describe the current behavior

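The reduction ops [All, Any, Min, Max, Mean, Prod, Sum] do not support zero-size tensors (tensors with a 0 in their shape). In TF, these ops reduce along the given axis and leave the other dimensions unchanged. Each op starts from an initial value (False for Any, True for All, 0 for Sum, 1 for Prod, -inf for Max, inf for Min), and that value is what an empty reduction returns. Mean returns NaN because it divides by a zero count: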

#bool input
>>> x = tf.constant([], shape=[0], dtype=tf.bool)
>>> tf.raw_ops.Any(input=x, axis=0, keep_dims=False).numpy()
False
>>> tf.raw_ops.All(input=x, axis=0, keep_dims=False).numpy()
True

#float32 input
>>> x = tf.constant([], shape=[0], dtype=tf.float32)
>>> tf.raw_ops.Mean(input=x, axis=0, keep_dims=False).numpy()
nan
>>> tf.raw_ops.Max(input=x, axis=0, keep_dims=False).numpy()
-inf
>>> tf.raw_ops.Min(input=x, axis=0, keep_dims=False).numpy()
inf
>>> tf.raw_ops.Prod(input=x, axis=0, keep_dims=False).numpy()
1.0
>>> tf.raw_ops.Sum(input=x, axis=0, keep_dims=False).numpy()
0.0

#multiple dimensions
>>> x = tf.constant([], shape=[0, 2], dtype=tf.bool)
>>> tf.raw_ops.All(input=x, axis=0, keep_dims=False).numpy()
array([ True,  True])
>>> x = tf.constant([], shape=[2, 0], dtype=tf.bool)
>>> tf.raw_ops.All(input=x, axis=0, keep_dims=False).numpy()
array([], dtype=bool)
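To make the expected semantics concrete, here is a minimal pure-Python reference sketch. It is not TF's or TFJS's actual implementation, and the helper name `reduce_axis` and the flat row-major `(values, shape)` tensor layout are assumptions for illustration. It reduces along one axis, keeps the remaining dimensions, and seeds every reduction with the op's initial value, so empty axes reproduce the TF outputs above.

```python
import math

# Initial (identity) value and pairwise combiner for each reduction,
# chosen to mirror the TF results shown above. Mean is handled separately.
_REDUCERS = {
    "Sum": (0.0, lambda a, b: a + b),
    "Prod": (1.0, lambda a, b: a * b),
    "Max": (-math.inf, max),
    "Min": (math.inf, min),
    "Any": (False, lambda a, b: a or b),
    "All": (True, lambda a, b: a and b),
}


def reduce_axis(values, shape, axis, op, keep_dims=False):
    """Hypothetical helper: reduce a flat row-major tensor along `axis`.

    Returns (flat_result, result_shape). Other dimensions are kept as-is.
    """
    outer = math.prod(shape[:axis])      # product of dims before the axis
    n = shape[axis]                      # length of the reduced axis (may be 0)
    inner = math.prod(shape[axis + 1:])  # product of dims after the axis
    result = []
    for o in range(outer):
        for i in range(inner):
            # Gather every element along the reduced axis for this (o, i) slot.
            elems = [values[(o * n + k) * inner + i] for k in range(n)]
            if op == "Mean":
                # Sum / count; an empty axis gives 0 / 0, i.e. NaN, as in TF.
                result.append(sum(elems) / n if n else math.nan)
            else:
                acc, combine = _REDUCERS[op]
                for e in elems:
                    acc = combine(acc, e)
                result.append(acc)
    kept = [1] if keep_dims else []
    return result, shape[:axis] + kept + shape[axis + 1:]
```

For example, `reduce_axis([], [0, 2], 0, "All")` gives `([True, True], [2])`, while `reduce_axis([], [2, 0], 0, "All")` gives `([], [0])`: when the kept dimensions themselves contain a 0, there are no output slots, so no initial value is ever emitted and the result is just an empty tensor.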

Describe the expected behavior

TFJS should match the TF Python results above and should not throw errors.

ref https://github.com/tensorflow/tfjs/issues/6605

pyu10055, Jul 19 '22 23:07