
Model stuck when calling .fit(x, y) using negative binomial in DistributionLambda Layer

Open · aegonwolf opened this issue 3 years ago · 1 comment

Hi all,

I have a simple BNN that I just tried to change to have a negative binomial distribution as output:

# Assumes the usual imports, e.g.:
#   from tensorflow.keras.layers import Input, BatchNormalization, Dense, Concatenate
#   from tensorflow.keras.models import Model
#   import tensorflow_probability as tfp
#   tfd, tfpl = tfp.distributions, tfp.layers
# get_posterior / get_prior are my prior/posterior factories (defined elsewhere).
def get_model(input_shape, loss, optimizer, metrics, kl_weight, output_shape):
    inputs = Input(shape=input_shape)
    x = BatchNormalization()(inputs)
    x = tfpl.DenseVariational(units=128, activation='tanh',
                              make_posterior_fn=get_posterior, make_prior_fn=get_prior,
                              kl_weight=kl_weight)(x)
    count = Dense(1)(x)                                     # head for total_count (linear)
    logits = Dense(output_shape, activation='sigmoid')(x)   # head for probs (sigmoid)
    neg_binom = tfp.layers.DistributionLambda(
            lambda t: tfd.NegativeBinomial(total_count=t[..., 0:1], probs=t[..., 1:]))
    cat = Concatenate(axis=-1)([count, logits])
    outputs = neg_binom(cat)
    model = Model(inputs, outputs)
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    return model

I do not get an error; the model compiles, but when I call model.fit(x, y) I just get:

Epoch 1/500

and it stays stuck there indefinitely (the longest I waited was about 20 minutes).

When I use a Poisson layer, as I did before, fitting starts instantly and an epoch takes about 1 s.

What could be the cause of this? Is there something wrong with my code above? I was hoping to call params_size, but DistributionLambda does not seem to support it (just in case I am missing something).
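For comparison, this is the params_size pattern I mean, as it works with the built-in layers such as tfp.layers.IndependentPoisson (a sketch only, names as I understand the API):

# Sketch: the built-in TFP layers expose params_size(), so the preceding
# Dense layer can be sized automatically to match the distribution's parameters.
import tensorflow as tf
import tensorflow_probability as tfp
tfpl = tfp.layers

event_shape = 1
poisson_head = tf.keras.Sequential([
    tf.keras.layers.Dense(tfpl.IndependentPoisson.params_size(event_shape)),
    tfpl.IndependentPoisson(event_shape),
])

DistributionLambda has no such helper, so for the negative binomial I size and slice the parameter tensor by hand, as in the model above.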

If I use a single Dense layer (linear activation, no Concatenate) I get the same behavior.
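One thing I am not sure about is whether the NegativeBinomial parameters need to be constrained (total_count strictly positive, probs strictly inside (0, 1)); Poisson has no count parameter, which might be why that version trains fine, but that is only a guess. A constrained variant of the head that I could try would look roughly like this (just a sketch, not something I have verified fixes the hang):

# Sketch: force total_count to be positive via softplus and keep the
# sigmoid-activated probs away from exactly 0 or 1. Same tfd / tfp aliases
# as in the model above.
neg_binom = tfp.layers.DistributionLambda(
    lambda t: tfd.NegativeBinomial(
        total_count=tf.math.softplus(t[..., 0:1]) + 1e-5,        # strictly > 0
        probs=tf.clip_by_value(t[..., 1:], 1e-5, 1.0 - 1e-5)))   # strictly in (0, 1)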

Many thanks for any insights and tips on things to try to debug this behavior.

— aegonwolf, Dec 23 '22 16:12

For more information: it is stuck in traceback_utils.py, in this function:

def filter_traceback(fn):
  """Filter out Keras-internal stack trace frames in exceptions raised by fn."""
  if sys.version_info.major != 3 or sys.version_info.minor < 7:
    return fn

  def error_handler(*args, **kwargs):
    if not tf.debugging.is_traceback_filtering_enabled():
      return fn(*args, **kwargs)

    filtered_tb = None
    try:
      return fn(*args, **kwargs)
    except Exception as e:  # pylint: disable=broad-except
      filtered_tb = _process_traceback_frames(e.__traceback__)
      raise e.with_traceback(filtered_tb) from None
    finally:
      del filtered_tb

Specifically, it is sitting at the line try: return fn(*args, **kwargs).
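In case the filtering hides anything relevant, the full unfiltered stack can be shown with the tf.debugging helper that this function checks (assuming a recent TF 2.x):

import tensorflow as tf

# Show full, unfiltered stack traces from Keras instead of the filtered ones.
tf.debugging.disable_traceback_filtering()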

Disabling eager execution gets me as far as the first backward step.
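For reference, this is roughly what I mean by disabling eager execution, plus the opposite direction (forcing fit() to run eagerly), in case that is a more useful way to debug this:

import tensorflow as tf

# Force graph mode ("disabling eager execution").
tf.compat.v1.disable_eager_execution()

# Alternative for debugging: keep eager execution and make fit() run eagerly,
# so the stuck step can be inspected with ordinary Python tools.
# model.compile(optimizer=optimizer, loss=loss, metrics=metrics, run_eagerly=True)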

— aegonwolf, Dec 23 '22 19:12