tensor_spec.sample_spec_nest can't deal with dtype tf.bool
The function sample_spec_nest currently raises a TypeError if any of the specs has dtype bool. For example, the following code:
import tensorflow as tf
from tf_agents.specs import tensor_spec
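# TensorSpec must be qualified; a bare TensorSpec raises a NameError before reaching sample_spec_nest.
spec = tf.TensorSpec(shape=(1, 40), dtype=tf.bool, name='mask')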
random_input = tensor_spec.sample_spec_nest(spec)
This raises the following error:
TypeError: Cannot find minimum value of <dtype: 'bool'>.
sample_spec_nest already special-cases spec.dtype == tf.string, so I think it would make sense to add a similar special case for tf.bool.
In my case, sample_spec_nest is called from Network.create_variables() for a custom Network that takes a boolean tensor as one of its inputs. Boolean inputs seem fairly common (e.g. masks), so it would be nice if they were supported.
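The tf.bool special case proposed above could look roughly like the following. This is a minimal sketch, assuming TensorFlow 2.x in eager mode and a fully defined spec shape; sample_bool_spec is a hypothetical helper, not part of tf_agents, and its outer_dims parameter is only modeled loosely on the idea of prepending batch dimensions. It draws integers in {0, 1} and casts them to bool, which avoids asking for the minimum value of a bool dtype.

```python
import tensorflow as tf


def sample_bool_spec(spec, seed=None, outer_dims=()):
    """Hypothetical helper: draw a random boolean tensor matching `spec`."""
    if spec.dtype != tf.bool:
        raise ValueError("Expected a tf.bool spec, got %s" % spec.dtype)
    if not spec.shape.is_fully_defined():
        raise ValueError("This sketch only supports fully defined shapes.")
    # Prepend any requested outer (e.g. batch) dimensions to the spec shape.
    shape = list(outer_dims) + spec.shape.as_list()
    # maxval is exclusive, so this samples integers in {0, 1}.
    ints = tf.random.uniform(shape, minval=0, maxval=2, dtype=tf.int32, seed=seed)
    return tf.cast(ints, tf.bool)


spec = tf.TensorSpec(shape=(1, 40), dtype=tf.bool, name='mask')
mask = sample_bool_spec(spec, seed=0)
batched = sample_bool_spec(spec, outer_dims=(3,))
```

Going through an integer sample keeps the booleans roughly balanced between True and False, which is usually what you want when the sample is only used to build a Network's variables.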
Feel free to send a PR to fix this.