tensor_spec.sample_spec_nest can't deal with dtype tf.bool

Open: RR-28023 opened this issue 4 years ago • 1 comment

The function sample_spec_nest currently raises a TypeError if any of the specs has dtype tf.bool. For example, the code below:

import tensorflow as tf
from tf_agents.specs import tensor_spec

spec = tensor_spec.TensorSpec(shape=(1, 40), dtype=tf.bool, name='mask')
random_input = tensor_spec.sample_spec_nest(spec)

Raises the error:

TypeError: Cannot find minimum value of <dtype: 'bool'>.

sample_spec_nest already special-cases spec.dtype == tf.string, so I think it would make sense to add another special case for tf.bool.

My code ends up calling sample_spec_nest through Network.create_variables(), for a custom Network that takes a boolean tensor as one of its inputs. I guess that is a common thing to do (e.g. for masks), so it would be nice if it were supported.
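In the meantime, a possible workaround is to sample the boolean specs separately and hand everything else to the existing sampler. The following is only a minimal sketch, not the library's fix: `sample_nest_with_bool` and its `fallback` parameter are hypothetical names, and it assumes every bool spec has a fully defined shape.

```python
import tensorflow as tf


def sample_nest_with_bool(spec_nest, fallback, outer_dims=()):
    """Samples a nest of specs, handling tf.bool specs locally.

    `fallback` is a hypothetical parameter: a callable taking
    (spec, outer_dims=...) that samples every non-bool spec, e.g.
    tf_agents.specs.tensor_spec.sample_spec_nest.
    """

    def _sample(spec):
        if spec.dtype == tf.bool:
            # Assumes a fully defined shape (as_list() fails on unknown rank).
            shape = tuple(outer_dims) + tuple(spec.shape.as_list())
            # Draw integers from {0, 1} and cast them to booleans.
            bits = tf.random.uniform(shape, minval=0, maxval=2, dtype=tf.int32)
            return tf.cast(bits, tf.bool)
        return fallback(spec, outer_dims=outer_dims)

    return tf.nest.map_structure(_sample, spec_nest)
```

With TF-Agents installed, this would be called as sample_nest_with_bool(spec, tensor_spec.sample_spec_nest). Note that it does not help when the library calls sample_spec_nest internally, which is the case in Network.create_variables(), so a fix inside sample_spec_nest is still needed.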

RR-28023 · Jun 30 '21 17:06

Feel free to send a PR to fix this.

sguada · Jul 26 '21 20:07