
Bug: Universal Sentence Encoder (USE) is not compatible with tf.distribute

Open jeisinge opened this issue 3 years ago • 0 comments

What happened?

The USE DAN model is an efficient embedder for short phrases, and it trains well on a single GPU. However, it fails to train on multiple GPUs with tf.distribute.

A couple of related defects have been reported and closed, but the issue remains. The closest one I found is #515. The workaround proposed there by RobRomijnders is to use strategy.run(); however, I don't understand how to do this with Keras. Specifically, calling this method returned a PerReplica object, and I don't know how to merge it back into a regular Keras layer. See https://github.com/tensorflow/hub/issues/515#issuecomment-699928052 for the proposed workaround.
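For context on the PerReplica value mentioned above, here is a minimal sketch of how per-replica results can be unpacked with strategy.experimental_local_results(). It does not load USE and is not a fix for the hub.KerasLayer failure. The function replica_fn, the step wrapper, and the two logical CPU devices are illustrative assumptions, not part of the original report.

```python
import tensorflow as tf

# Assumption: split the physical CPU into two logical devices so that
# MirroredStrategy gets two replicas. This must run before TensorFlow
# initializes its devices (i.e. in a fresh process).
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu,
    [tf.config.LogicalDeviceConfiguration(),
     tf.config.LogicalDeviceConfiguration()],
)

strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])


def replica_fn(x):
    # Hypothetical per-replica computation standing in for a model call.
    return x * 2.0


@tf.function
def step(x):
    # With two replicas, strategy.run returns a PerReplica value.
    return strategy.run(replica_fn, args=(x,))


per_replica = step(tf.constant([1.0, 2.0]))

# Unpack the PerReplica value into a tuple with one tensor per replica,
# then combine the pieces into a single ordinary tensor.
local = strategy.experimental_local_results(per_replica)
merged = tf.concat(local, axis=0)
print(len(local), merged.shape)
```

This only shows how to get plain tensors back from strategy.run() in a custom loop; whether the same approach can be wired into a Keras model is exactly the open question in this report.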

The model page at https://tfhub.dev/google/universal-sentence-encoder/4 claims the model is TF2; however, the SavedModel itself states that it was produced with TF 1.15. If it were a true TF2 model, I believe it wouldn't have an issue with tf.distribute.
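One way to check which TensorFlow version a SavedModel records is to read the tensorflow_version field from its saved_model.pb. The sketch below assumes a SavedModel directory on local disk; the helper name saved_model_tf_versions is hypothetical, and the demo saves an empty tf.Module rather than downloading USE so that it runs offline.

```python
import os
import tempfile

import tensorflow as tf
from tensorflow.core.protobuf import saved_model_pb2


def saved_model_tf_versions(export_dir):
    """Return the TensorFlow version recorded in each MetaGraph."""
    proto = saved_model_pb2.SavedModel()
    with open(os.path.join(export_dir, "saved_model.pb"), "rb") as f:
        proto.ParseFromString(f.read())
    return [mg.meta_info_def.tensorflow_version for mg in proto.meta_graphs]


# Demo on a throwaway local SavedModel (not USE).
demo_dir = tempfile.mkdtemp()
tf.saved_model.save(tf.Module(), demo_dir)
versions = saved_model_tf_versions(demo_dir)
print(versions)
```

For the USE model, the directory could presumably be obtained with hub.resolve(handle), which downloads and caches the model; that step is left out here because it needs network access.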

Colab notebook reproducing the error: https://colab.research.google.com/drive/1vgzBxzojLToHqR1XSGhmBepna1RfhXZG?usp=sharing

Relevant code

# See https://colab.research.google.com/drive/1vgzBxzojLToHqR1XSGhmBepna1RfhXZG?usp=sharing

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_hub as hub

def make_use_4_model():
  inputs = keras.Input(shape=(), dtype=tf.dtypes.string, name="text_inputs")
  use = hub.KerasLayer("https://tfhub.dev/google/universal-sentence-encoder/4")
  outputs = use(inputs)
  return keras.Model(inputs, outputs, name="use")

strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])  # Two CPU devices stand in for multiple GPUs
with strategy.scope():
  distributed_encoder_4 = hub.KerasLayer(
    handle="https://tfhub.dev/google/universal-sentence-encoder/4",
  )

Relevant log output

7 frames

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     53     ctx.ensure_initialized()
     54     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 55                                         inputs, attrs, num_outputs)
     56   except core._NotOkStatusException as e:
     57     if name is not None:

InvalidArgumentError: Graph execution error:

Detected at node 'Assert/Assert' defined at (most recent call last):
    File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
      "__main__", mod_spec)
    File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
      exec(code, run_globals)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
      app.launch_new_instance()
    File "/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py", line 846, in launch_instance
      app.start()
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py", line 499, in start
      self.io_loop.start()
    File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 132, in start
      self.asyncio_loop.run_forever()
    File "/usr/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
      self._run_once()
    File "/usr/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
      handle._run()
    File "/usr/lib/python3.7/asyncio/events.py", line 88, in _run
      self._context.run(self._callback, *self._args)
    File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 122, in _handle_events
      handler_func(fileobj, events)
    File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 577, in _handle_events
      self._handle_recv()
    File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 606, in _handle_recv
      self._run_callback(callback, msg)
    File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 556, in _run_callback
      callback(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
      return self.dispatch_shell(stream, msg)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
      handler(stream, idents, msg)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
      user_expressions, allow_stdin)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
      return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2718, in run_cell
      interactivity=interactivity, compiler=compiler, result=result)
    File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2822, in run_ast_nodes
      if self.run_code(code, result):
    File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "<ipython-input-13-246cb39154ce>", line 3, in <module>
      handle="https://tfhub.dev/google/universal-sentence-encoder/4",
    File "/usr/local/lib/python3.7/dist-packages/tensorflow_hub/keras_layer.py", line 153, in __init__
      self._func = load_module(handle, tags, self._load_options)
    File "/usr/local/lib/python3.7/dist-packages/tensorflow_hub/keras_layer.py", line 449, in load_module
      return module_v2.load(handle, tags=tags, options=set_load_options)
    File "/usr/local/lib/python3.7/dist-packages/tensorflow_hub/module_v2.py", line 106, in load
      obj = tf.compat.v1.saved_model.load_v2(module_path, tags=tags)
Node: 'Assert/Assert'
assertion failed: [Trying to access a placeholder that is not supposed to be executed. This means you are executing a graph generated from the cross-replica context in an in-replica context.]
	 [[{{node Assert/Assert}}]] [Op:__inference_restored_function_body_56858]
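The assertion message above distinguishes the cross-replica context from the in-replica context. As a rough illustration of those two terms (not a diagnosis of the failure), the sketch below records which context is active inside strategy.scope() and inside a function passed to strategy.run(). It uses OneDeviceStrategy so it runs on a single CPU; the names replica_fn and observed are hypothetical.

```python
import tensorflow as tf

strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
observed = {}


def replica_fn():
    # Code passed to strategy.run executes in an in-replica context.
    observed["inside_run"] = tf.distribute.in_cross_replica_context()
    return tf.constant(0.0)


with strategy.scope():
    # Code directly under strategy.scope() executes in the cross-replica context.
    observed["inside_scope"] = tf.distribute.in_cross_replica_context()
    strategy.run(replica_fn)

print(observed)
```

In the report, hub.KerasLayer is constructed directly under strategy.scope(), so the SavedModel is loaded in the cross-replica context, which matches the wording of the assertion.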

tensorflow_hub Version

0.12.0 (latest stable release)

TensorFlow Version

2.8 (latest stable release)

Other libraries

No response

Python Version

3.x

OS

Linux

jeisinge, Jul 27 '22 23:07