Model stuck when calling .fit(x, y) using negative binomial in DistributionLambda Layer
Hi all,
I have a simple BNN that I just tried to change to have a negative binomial distribution as output:
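
    # Imports assumed from the names used below.
    import tensorflow_probability as tfp
    from tensorflow.keras.layers import BatchNormalization, Concatenate, Dense, Input
    from tensorflow.keras.models import Model

    tfd = tfp.distributions
    tfpl = tfp.layers

    # get_posterior and get_prior are defined elsewhere (not shown).
    def get_model(input_shape, loss, optimizer, metrics, kl_weight, output_shape):
        inputs = Input(shape=(input_shape))
        x = BatchNormalization()(inputs)
        x = tfpl.DenseVariational(units=128, activation='tanh', make_posterior_fn=get_posterior, make_prior_fn=get_prior, kl_weight=kl_weight)(x)
        count = Dense(1)(x)  # linear output, used as total_count
        logits = Dense(output_shape, activation='sigmoid')(x)  # values in (0, 1), used as probs
        neg_binom = tfp.layers.DistributionLambda(
            lambda t: tfd.NegativeBinomial(total_count=t[..., 0:1], probs=t[..., 1:]))
        cat = Concatenate(axis=-1)([count, logits])
        outputs = neg_binom(cat)
        model = Model(inputs, outputs)
        model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
        return model
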
I do not get an error. The model compiles, and when I call model.fit(x, y) I just get:

    Epoch 1/500

and it stays stuck there (the longest I waited was about 20 minutes).
With a Poisson layer, which I used before, fitting starts instantly and an epoch takes about 1 s.
What could be the cause of this? Is there something wrong with my code above? I was hoping to call params_size, but DistributionLambda does not seem to support it (unless I am missing something).
If I use a single Dense layer with a linear activation, without the Concatenate, I get the same behavior.
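One thing that may be worth ruling out is the parameter domain. Both variants feed a linear Dense output into total_count, so it can be negative or zero. The sketch below is plain Python, not TFP: it assumes the parameterization I understand tfd.NegativeBinomial to use (total_count > 0, probs in (0, 1) as the success probability), and the names softplus and neg_binomial_log_prob are made up for illustration. Whether this is related to the hang is a guess.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x)); positive for moderate x."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def neg_binomial_log_prob(k, total_count, probs):
    """Log-pmf of count k under the parameterization assumed in the text."""
    if total_count <= 0.0:
        raise ValueError(f"total_count must be > 0, got {total_count}")
    if not 0.0 < probs < 1.0:
        raise ValueError(f"probs must lie in (0, 1), got {probs}")
    return (math.lgamma(k + total_count) - math.lgamma(k + 1.0)
            - math.lgamma(total_count)
            + k * math.log(probs) + total_count * math.log1p(-probs))


raw_count = -1.7  # a value that a linear Dense(1) output can easily produce

try:
    neg_binomial_log_prob(3, raw_count, 0.4)
except ValueError as err:
    print("rejected:", err)

# Passed through a softplus, the same raw output becomes a valid total_count.
print(neg_binomial_log_prob(3, softplus(raw_count), 0.4))
```

In the model this would correspond to something like activation='softplus' on the Dense(1) layer that produces count, which is a change I have not tested.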
Many thanks for your insights and tips on things to try and debug this behavior.
For more information, execution is stuck in traceback_utils.py, in this function:

    def filter_traceback(fn):
        """Filter out Keras-internal stack trace frames in exceptions raised by fn."""
        if sys.version_info.major != 3 or sys.version_info.minor < 7:
            return fn

        def error_handler(*args, **kwargs):
            if not tf.debugging.is_traceback_filtering_enabled():
                return fn(*args, **kwargs)

            filtered_tb = None
            try:
                return fn(*args, **kwargs)
            except Exception as e:  # pylint: disable=broad-except
                filtered_tb = _process_traceback_frames(e.__traceback__)
                raise e.with_traceback(filtered_tb) from None
            finally:
                del filtered_tb

specifically at this point:

    try:
        return fn(*args, **kwargs)

Disabling eager execution gets me as far as the first backward step.
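Since filter_traceback only wraps the call, being stuck at return fn(*args, **kwargs) presumably means the hang is somewhere inside the wrapped function rather than in the wrapper itself. A possible way to see where is the standard-library faulthandler module, which can dump the Python stack of every thread after a timeout. This is only a sketch: run_with_stack_dump and slow_step are hypothetical names, and slow_step stands in for the real model.fit(x, y) call.

```python
import faulthandler
import tempfile
import time


def run_with_stack_dump(fn, timeout_s, *args, **kwargs):
    """Run fn; if it is still running after timeout_s seconds, dump thread stacks.

    Returns (result, dump_text). dump_text is empty if fn finished in time.
    """
    with tempfile.TemporaryFile(mode="w+") as log:
        # Arm a watchdog that writes the Python stacks of all threads to `log`.
        faulthandler.dump_traceback_later(timeout_s, file=log)
        try:
            result = fn(*args, **kwargs)
        finally:
            # Disarm the watchdog whether or not it has fired.
            faulthandler.cancel_dump_traceback_later()
        log.seek(0)
        return result, log.read()


def slow_step():
    # Hypothetical stand-in for the call that hangs, e.g. model.fit(x, y).
    time.sleep(0.5)
    return "done"


result, dump = run_with_stack_dump(slow_step, 0.1)
print(dump)
```

For a call that never returns, passing file=sys.stderr and repeat=True to faulthandler.dump_traceback_later would print the stacks periodically instead of waiting for the function to finish. The dump only shows Python frames, so time spent inside a native TensorFlow op would appear as the Python line that invoked it.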