
bit mask support for dropout

Open eric-haibin-lin opened this issue 5 years ago • 4 comments

For dropout training, one can save the dropout mask with 1 bit per coordinate. Can we support that in DNNL? Memory is precious.
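To make the memory argument concrete, here is an illustrative NumPy sketch (not a oneDNN API) of storing the dropout keep-mask with 1 bit per coordinate via `np.packbits`, an 8x saving over a boolean mask:

```python
import numpy as np

rng = np.random.default_rng(42)
p = 0.5                                       # drop probability
x = rng.standard_normal(1024).astype(np.float32)

keep = rng.random(x.shape) >= p               # boolean mask: 1 byte/element
bitmask = np.packbits(keep)                   # packed mask: 1 bit/element

# Forward pass with "inverted dropout" scaling by 1/(1-p)
y = np.where(keep, x / (1.0 - p), 0.0).astype(np.float32)

# Backward pass can unpack the saved bits instead of keeping `keep` alive
keep_restored = np.unpackbits(bitmask, count=x.size).astype(bool)
assert np.array_equal(keep, keep_restored)
```

For this example the packed mask takes 128 bytes instead of the 1024 bytes a boolean mask would occupy.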

eric-haibin-lin avatar Feb 27 '20 04:02 eric-haibin-lin

Hi @eric-haibin-lin,

Thank you for your question. Technically, nothing prevents us from introducing a dropout primitive in the library, including 1-bit mask support. The main question we need to answer to make this happen is what the API and behavior should look like to make the functionality generally useful. For dropout, the main source of concern is that it relies on a random number generator, which may behave differently across applications, so having a random number generator as part of the implementation would be a major source of incompatibility and thread-safety issues.

A couple of follow-up questions so that I can better understand what you are looking for:

  • What do you expect from a DNNL implementation (vs. implementing the functionality directly in C++)?
  • What API would make sense to you? Would a function that takes a pre-computed mask and performs dropout be viable?
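The second option above could be sketched as follows. This is a hypothetical shape, not a real DNNL signature: the caller generates the mask with an RNG of their choosing and the primitive only applies it, which sidesteps the RNG compatibility and thread-safety concerns entirely.

```python
import numpy as np

def dropout_apply(src: np.ndarray, bitmask: np.ndarray, p: float) -> np.ndarray:
    """Apply a caller-provided 1-bit-per-element keep-mask with 1/(1-p) scaling.

    Illustrative only; `dropout_apply` is not an actual DNNL function.
    """
    keep = np.unpackbits(bitmask, count=src.size).astype(bool).reshape(src.shape)
    return np.where(keep, src / (1.0 - p), 0.0).astype(src.dtype)

# The caller fully controls RNG choice, seeding, and thread safety:
rng = np.random.default_rng(0)
x = rng.standard_normal(256).astype(np.float32)
mask = np.packbits(rng.random(x.size) >= 0.5)
y = dropout_apply(x, mask, p=0.5)
```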

vpirogov avatar Feb 28 '20 17:02 vpirogov

I think the random number generator takes a significant share of dropout's execution time. That's why we optimized it in MXNet with viRngBernoulli from VSL. But viRngBernoulli can no longer meet the requirement once bit mask generation is needed, so I would expect DNNL to cover the RNG part. The forward interface would then look like:

  • input: source data, random seed, distribution type, mask type, p value
  • output: destination data, workspace for the mask

where the mask type can be a bit mask, boolean mask, or integer mask. I'm not sure whether boolean or integer masks have any advantage, but they are used in frameworks.
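The interface described above could be sketched like this (illustrative names and signature, not a real DNNL API; a NumPy generator stands in for whatever RNG the library would own):

```python
import numpy as np

def dropout_forward(src, seed, p, mask_type="bit"):
    """Hypothetical forward interface: library-owned RNG seeded by the framework.

    Returns the destination data plus a workspace mask in the requested format.
    """
    rng = np.random.default_rng(seed)      # stand-in for a library-internal RNG
    keep = rng.random(src.shape) >= p
    dst = np.where(keep, src / (1.0 - p), 0.0).astype(src.dtype)
    if mask_type == "bit":
        workspace = np.packbits(keep)      # 1 bit per element
    elif mask_type == "bool":
        workspace = keep                   # 1 byte per element
    else:                                  # integer mask, as some frameworks use
        workspace = keep.astype(np.int32)
    return dst, workspace

x = np.arange(16, dtype=np.float32)
dst1, ws1 = dropout_forward(x, seed=7, p=0.5, mask_type="bit")
dst2, ws2 = dropout_forward(x, seed=7, p=0.5, mask_type="bit")
```

Because the framework supplies the seed, the same seed yields the same mask and output, which is the reproducibility property such an interface would need to guarantee.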

@eric-haibin-lin @apeforest Could you please share more insights about the random seed distribution in MXNet and reproducibility of the operator?

TaoLv avatar Mar 02 '20 06:03 TaoLv

The random seed should be taken from MXNet, so that if a user specifies a random seed in MXNet, reproducibility is guaranteed. A similar approach has been taken for the cuDNN library: https://github.com/apache/incubator-mxnet/pull/17547

apeforest avatar Mar 04 '20 07:03 apeforest

I notice an RFC has been opened for this request. You may want to take a look. @eric-haibin-lin @apeforest @pengzhao-intel

TaoLv avatar Jun 20 '20 02:06 TaoLv