Bug: Universal Sentence Encoder (USE) is not compatible with tf.distribute
What happened?
The USE DAN model is an efficient embedder for short phrases, and it trains well on a single GPU. However, it fails to train on multiple GPUs with tf.distribute.
A couple of related issues have been reported and closed, but the problem remains. The closest one I found is #515. The workaround proposed there by RobRomijnders is to use strategy.run(); however, I don't understand how to do this with Keras. Specifically, calling strategy.run() returned a PerReplica object, and I don't know how to merge it back into a regular Keras layer. See https://github.com/tensorflow/hub/issues/515#issuecomment-699928052.
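For reference, here is a minimal sketch of what the strategy.run() route involves outside of Keras. It assumes a hypothetical `fake_embed` function standing in for USE (so nothing is downloaded) and two logical CPUs standing in for GPUs. It only shows that the PerReplica value returned by strategy.run() can be unwrapped with strategy.experimental_local_results() and concatenated back into one batch; it does not wire this into a Keras model, and it is not a confirmed fix for loading USE under tf.distribute.

```python
import tensorflow as tf

# Split the physical CPU into two logical CPUs so MirroredStrategy has two
# replicas, mimicking the "fake multiple GPUs" setup. This only works before
# TensorFlow initializes its devices; otherwise it raises RuntimeError.
cpus = tf.config.list_physical_devices("CPU")
try:
    tf.config.set_logical_device_configuration(
        cpus[0],
        [tf.config.LogicalDeviceConfiguration(),
         tf.config.LogicalDeviceConfiguration()],
    )
except RuntimeError:
    pass  # Devices already initialized; fall back to whatever CPUs exist.

devices = [d.name for d in tf.config.list_logical_devices("CPU")][:2]
strategy = tf.distribute.MirroredStrategy(devices)


def fake_embed(texts):
    # Hypothetical stand-in for the USE layer: maps each string to a 4-dim
    # vector filled with its length, so the example runs offline.
    lengths = tf.cast(tf.strings.length(texts), tf.float32)
    return tf.tile(lengths[:, None], [1, 4])


@tf.function
def distributed_embed(batch):
    # Each replica embeds its shard of the global batch; the result is a
    # PerReplica value (or a plain tensor when there is only one replica).
    per_replica = strategy.run(fake_embed, args=(batch,))
    # Unwrap it into a tuple of per-replica tensors and stitch them back
    # into a single global batch along the batch axis.
    return tf.concat(strategy.experimental_local_results(per_replica), axis=0)


texts = tf.constant(["a", "bb", "ccc", "dddd"])
dataset = tf.data.Dataset.from_tensor_slices(texts).batch(4)  # global batch of 4
dist_dataset = strategy.experimental_distribute_dataset(dataset)

for batch in dist_dataset:
    embeddings = distributed_embed(batch)
    print(embeddings.shape)
```

With two replicas, each one embeds two of the four strings, and the concatenated result is again a single tensor of shape (4, 4) that ordinary downstream code can consume.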
https://tfhub.dev/google/universal-sentence-encoder/4
The model page claims it is a TF2 model; however, the SavedModel states that it is TF 1.15. If it were a TF2 model, I believe it would not have this issue with tf.distribute.
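The "TF 1.15" above is presumably the exporter version that every SavedModel records in its MetaGraph's `meta_info_def.tensorflow_version` field. A minimal sketch for reading that field directly, assuming the standard `saved_model.pb` file layout; the helper `saved_model_tf_versions` and the tiny `Doubler` module are hypothetical, used only to produce something to inspect. Note that this field only says which TF version wrote the file.

```python
import os
import tempfile

import tensorflow as tf
from tensorflow.core.protobuf import saved_model_pb2


def saved_model_tf_versions(export_dir):
    """Return the exporter TF version recorded in each MetaGraph of a SavedModel."""
    saved_model = saved_model_pb2.SavedModel()
    with open(os.path.join(export_dir, "saved_model.pb"), "rb") as f:
        saved_model.ParseFromString(f.read())
    return [mg.meta_info_def.tensorflow_version for mg in saved_model.meta_graphs]


class Doubler(tf.Module):
    # Tiny hypothetical module used only to produce a SavedModel to inspect.
    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return 2.0 * x


export_dir = tempfile.mkdtemp()
tf.saved_model.save(Doubler(), export_dir)
versions = saved_model_tf_versions(export_dir)
print(versions)
```

To inspect USE itself, the helper could be pointed at the locally cached copy, e.g. `saved_model_tf_versions(hub.resolve("https://tfhub.dev/google/universal-sentence-encoder/4"))`, assuming `hub.resolve` is available in the installed tensorflow_hub version.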
Colab notebook reproducing the error: https://colab.research.google.com/drive/1vgzBxzojLToHqR1XSGhmBepna1RfhXZG?usp=sharing
Relevant code
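# See https://colab.research.google.com/drive/1vgzBxzojLToHqR1XSGhmBepna1RfhXZG?usp=sharing
import tensorflow as tf
from tensorflow import keras
import tensorflow_hub as hub

# Split the physical CPU into two logical CPUs so that "CPU:1" exists.
# This must run before TensorFlow initializes its devices.
cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
    cpus[0],
    [tf.config.LogicalDeviceConfiguration(), tf.config.LogicalDeviceConfiguration()],
)

def make_use_4_model():
    inputs = keras.Input(shape=(), dtype=tf.dtypes.string, name="text_inputs")
    use = hub.KerasLayer("https://tfhub.dev/google/universal-sentence-encoder/4")
    outputs = use(inputs)
    return keras.Model(inputs, outputs, name="use")

strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])  # Fake multiple GPUs
with strategy.scope():
    # Loading the module inside the strategy scope raises the error below.
    distributed_encoder_4 = hub.KerasLayer(
        handle="https://tfhub.dev/google/universal-sentence-encoder/4",
    )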
Relevant log output
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
53 ctx.ensure_initialized()
54 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 55 inputs, attrs, num_outputs)
56 except core._NotOkStatusException as e:
57 if name is not None:
InvalidArgumentError: Graph execution error:
Detected at node 'Assert/Assert' defined at (most recent call last):
File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
app.launch_new_instance()
File "/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py", line 846, in launch_instance
app.start()
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py", line 499, in start
self.io_loop.start()
File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 132, in start
self.asyncio_loop.run_forever()
File "/usr/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
self._run_once()
File "/usr/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
handle._run()
File "/usr/lib/python3.7/asyncio/events.py", line 88, in _run
self._context.run(self._callback, *self._args)
File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 122, in _handle_events
handler_func(fileobj, events)
File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 577, in _handle_events
self._handle_recv()
File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 606, in _handle_recv
self._run_callback(callback, msg)
File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 556, in _run_callback
callback(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
return self.dispatch_shell(stream, msg)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
handler(stream, idents, msg)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
user_expressions, allow_stdin)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
res = shell.run_cell(code, store_history=store_history, silent=silent)
File "/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2718, in run_cell
interactivity=interactivity, compiler=compiler, result=result)
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2822, in run_ast_nodes
if self.run_code(code, result):
File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-13-246cb39154ce>", line 3, in <module>
handle="https://tfhub.dev/google/universal-sentence-encoder/4",
File "/usr/local/lib/python3.7/dist-packages/tensorflow_hub/keras_layer.py", line 153, in __init__
self._func = load_module(handle, tags, self._load_options)
File "/usr/local/lib/python3.7/dist-packages/tensorflow_hub/keras_layer.py", line 449, in load_module
return module_v2.load(handle, tags=tags, options=set_load_options)
File "/usr/local/lib/python3.7/dist-packages/tensorflow_hub/module_v2.py", line 106, in load
obj = tf.compat.v1.saved_model.load_v2(module_path, tags=tags)
Node: 'Assert/Assert'
assertion failed: [Trying to access a placeholder that is not supposed to be executed. This means you are executing a graph generated from the cross-replica context in an in-replica context.]
[[{{node Assert/Assert}}]] [Op:__inference_restored_function_body_56858]
tensorflow_hub Version
0.12.0 (latest stable release)
TensorFlow Version
2.8 (latest stable release)
Other libraries
No response
Python Version
3.7
OS
Linux