
Distributed Training with TensorFlow Java

Open danilojsl opened this issue 4 years ago • 8 comments


System information

  • TensorFlow version (you are using): 2.X
  • Are you willing to contribute it (Yes/No): Yes, when able and available

Describe the feature and the current behavior/state. In Python, TensorFlow has the tf.distribute.Strategy API to distribute training across multiple GPUs or multiple machines.

Will this change the current API? How? Yes, it will add a new awesome feature

Who will benefit from this feature?

  • Anyone who needs to speed up training of a DL model
  • Anyone who needs to train a DL model on big data
  • Anyone who wants to create or add Java support for APIs that leverage tf.distribute.Strategy, such as TensorFlowOnSpark, Spark TensorFlow Distributor, or Horovod

Any other info: https://www.tensorflow.org/guide/distributed_training

danilojsl avatar Aug 19 '21 14:08 danilojsl

We had a brief look into this several months ago, and it's basically a huge pile of Python (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/distribute) that we'd need to replicate in Java. Doing out-of-band distribution using MPI allreduce directly from Java would be relatively easy, though it would have much worse performance than the native TF solution, as it wouldn't do direct GPU-to-GPU copies.
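To make the out-of-band pattern concrete, here is a minimal in-process sketch: each worker computes local gradients, an allreduce averages them, and every worker applies the same averaged update. This is a hypothetical illustration (AllreduceSketch and allreduceMean are made-up names); real MPI would perform the reduction across machines rather than over in-memory arrays.

```java
// In-process simulation of the allreduce-mean step of data-parallel training.
// A real deployment would replace this with an MPI_Allreduce across ranks.
public class AllreduceSketch {

    // Element-wise mean of per-worker gradient vectors (the "allreduce" step).
    static float[] allreduceMean(float[][] localGrads) {
        int numWorkers = localGrads.length;
        float[] avg = new float[localGrads[0].length];
        for (float[] grads : localGrads) {
            for (int i = 0; i < grads.length; i++) {
                avg[i] += grads[i] / numWorkers;
            }
        }
        return avg;
    }
}
```

After the reduction, every worker holds the same averaged gradient and applies an identical optimizer step, keeping the replicas in sync.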

Craigacp avatar Aug 19 '21 14:08 Craigacp

Hi @Craigacp

Is this also true for inference? In prediction, each inference is isolated from the others, so it seems easier to batch inputs and send them to multiple GPU devices in parallel. (Just trying to see if inference over multiple GPU devices can happen in tensorflow-java.)

maziyarpanahi avatar Oct 26 '21 19:10 maziyarpanahi

Most of the ops are there (https://github.com/tensorflow/java/tree/55547dd20b14e1e9cd592a8789e780a0be3ae507/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective); I'm not sure if you can use them manually yet (we have no way to create groups or instances) or whether they can do GPU <-> GPU transfers (with NCCL, I assume).

Of course, if you have multiple GPUs on a single machine, you can just use the device settings.
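A hedged sketch of that single-machine device-settings approach, assuming tensorflow-core-platform with GPU support is on the classpath; DevicePlacement is an illustrative class name, while DeviceSpec and Ops.withDevice are the TF-Java device-placement API:

```java
import org.tensorflow.DeviceSpec;
import org.tensorflow.Graph;
import org.tensorflow.op.Ops;

public class DevicePlacement {
    public static void main(String[] args) {
        try (Graph g = new Graph()) {
            Ops tf = Ops.create(g);
            // Ops created through this scoped Ops instance are pinned to GPU 1.
            Ops gpu1 = tf.withDevice(DeviceSpec.newBuilder()
                    .deviceType(DeviceSpec.DeviceType.GPU)
                    .deviceIndex(1)
                    .build());
            gpu1.constant(42.0f); // this constant is placed on the second GPU
        }
    }
}
```

Building separate sub-scopes with different device indices lets you split a graph's ops across the GPUs on one machine, without any of the distribute-strategy machinery.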

rnett avatar Oct 26 '21 19:10 rnett

Thanks @rnett

We should do some testing. I wasn't sure if I could load something like BERT (where the ops and device assignment are out of my hands, since it's a SavedModel) and somehow use ConfigProto/Session to distribute over multiple GPU devices.

I'll see if the device settings/scope can be applied to a loaded SavedModel.

maziyarpanahi avatar Oct 26 '21 19:10 maziyarpanahi

That's more of a TF-core thing; I don't know if there's support for it, although it seems like a common enough use case. If you don't find a way, you may be able to create a ConcreteFunction out of your inference call and then call that on different GPUs.
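For reference, loading a SavedModel and getting at its serving signature as a callable function looks roughly like this in TF-Java; the model path is illustrative, while SavedModelBundle.load and function are the real API:

```java
import org.tensorflow.SavedModelBundle;

public class SavedModelInference {
    // Standard SavedModel serving tag and the Keras default signature name.
    static final String SERVE_TAG = "serve";
    static final String DEFAULT_SIGNATURE = "serving_default";

    public static void main(String[] args) {
        // "/path/to/bert" is illustrative; point this at a real export directory.
        try (SavedModelBundle model = SavedModelBundle.load("/path/to/bert", SERVE_TAG)) {
            // The signature can then be invoked like a function, e.g.:
            // Map<String, Tensor> out = model.function(DEFAULT_SIGNATURE).call(inputs);
        }
    }
}
```

Whether the resulting function's ops can be re-placed onto a different GPU after loading is exactly the open question in this thread.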

rnett avatar Oct 26 '21 19:10 rnett

I'll give it a shot and see if I can send each partition of inputs to a different available GPU device; even a simple round-robin might help.
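That round-robin dispatch can be sketched with plain JVM concurrency: one single-threaded executor per GPU (in practice each would own a session pinned to its device), with batches assigned in rotation. RoundRobinDispatcher and the Function stand-in for the per-device inference call are hypothetical:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.Function;

public class RoundRobinDispatcher<I, O> {
    private final List<ExecutorService> workers = new ArrayList<>();
    private final Function<I, O> infer; // stand-in for a per-device inference call
    private int next = 0;

    public RoundRobinDispatcher(int numGpus, Function<I, O> infer) {
        for (int i = 0; i < numGpus; i++) {
            // One worker thread per GPU; each would own a device-pinned session.
            workers.add(Executors.newSingleThreadExecutor());
        }
        this.infer = infer;
    }

    // Pure helper so the assignment policy is easy to test.
    static int deviceFor(int batchIndex, int numGpus) {
        return batchIndex % numGpus;
    }

    // Dispatch a batch to the next worker in rotation.
    public Future<O> submit(I batch) {
        ExecutorService worker = workers.get(next);
        next = (next + 1) % workers.size();
        return worker.submit(() -> infer.apply(batch));
    }

    public void shutdown() {
        workers.forEach(ExecutorService::shutdown);
    }
}
```

Keeping each session on its own thread also sidesteps any question of concurrent calls into the same session.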

Thanks again @rnett. Given the explosion of pretrained models for TF, this may become a feature in tensorflow-java one day.

maziyarpanahi avatar Oct 26 '21 19:10 maziyarpanahi

Is this also true for inference? In prediction, each inference is isolated from the others, so it seems easier to batch inputs and send them to multiple GPU devices in parallel. (Just trying to see if inference over multiple GPU devices can happen in tensorflow-java.)

I wonder how easily we can do that with a proper inference server like Triton, whose efficient-but-not-too-user-friendly-yet C API can be used from Java: https://github.com/bytedeco/javacpp-presets/tree/master/tritonserver

@jackyh What do you think?

saudet avatar Oct 26 '21 23:10 saudet

If we're talking about different packages, you can use ONNX Runtime sessions, one per GPU, in the same JVM. But let's get it working properly in TF-Java.
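For anyone curious, the one-session-per-GPU pattern in ONNX Runtime's Java API looks roughly like this; the model path and GPU count are illustrative, while OrtEnvironment, SessionOptions, and addCUDA are the real API:

```java
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

public class PerGpuSessions {
    static final int NUM_GPUS = 2; // illustrative

    public static void main(String[] args) throws OrtException {
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession[] sessions = new OrtSession[NUM_GPUS];
        for (int i = 0; i < NUM_GPUS; i++) {
            OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
            opts.addCUDA(i); // pin this session's CUDA execution provider to GPU i
            sessions[i] = env.createSession("model.onnx", opts); // path illustrative
        }
        // Each session can now serve requests on its own GPU from the same JVM.
    }
}
```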

A while ago there was a compilation issue when building TF-Java models on multiple GPUs, as the device placement algorithm got a bit confused by some of the optimisers. I guess I should check if that's still true now that we've upgraded TF multiple times.

Craigacp avatar Oct 27 '21 13:10 Craigacp