mcfly
find_best_architecture fails with tf.keras.metrics objects
When passing a tf.keras.metrics object as the metric argument of find_best_architecture(), e.g. find_best_architecture(..., metric=tf.keras.metrics.Precision()), the following error is raised:
    models = modelgen.generate_models(X_train.shape, y_train.shape[1],
                                      number_of_models=number_of_models,
                                      task=task,
                                      metrics=[metric],
                                      **kwargs)
    _, val_performance, _ = train_models_on_samples(X_train,
                                                    y_train,
                                                    X_val,
 (...)
                                                    model_path=model_path,
                                                    class_weight=class_weight)
--> best_model_index = np.argmax(val_performance[metric])
ValueError: attempt to get argmax of an empty sequence
The error occurs because find_best_architecture() stores the validation-set scores in the dictionary val_performance, keyed by the metric name string, and then looks them up with val_performance[metric]. When a Keras metric object is supplied instead of a string, the dictionary has no entry under that key, so the lookup returns an empty list and np.argmax fails on it.
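A minimal sketch of the failure mode and one possible fix, written in plain Python so it runs without TensorFlow or NumPy. FakeMetric, metric_key and argmax are hypothetical names introduced only for illustration: FakeMetric stands in for a tf.keras.metrics object (Keras metric objects do expose a name attribute, e.g. "precision" for Precision()), and argmax stands in for np.argmax. The sketch assumes val_performance maps metric names to per-model score lists and returns an empty list for unknown keys, as the error suggests; it is not mcfly's actual code.

```python
from collections import defaultdict


class FakeMetric:
    """Hypothetical stand-in for a tf.keras.metrics object such as Precision(),
    which exposes its name through a `.name` attribute."""

    def __init__(self, name):
        self.name = name


def metric_key(metric):
    """Hypothetical helper: map a metric (string or object) to a dict key.

    Strings pass through unchanged; metric objects are mapped to their
    `.name` attribute. Anything without a string name raises TypeError.
    """
    if isinstance(metric, str):
        return metric
    name = getattr(metric, "name", None)
    if isinstance(name, str):
        return name
    raise TypeError(f"Cannot derive a metric name from {metric!r}")


def argmax(values):
    """Stand-in for np.argmax: index of the first maximum, error if empty."""
    if not values:
        raise ValueError("attempt to get argmax of an empty sequence")
    return max(range(len(values)), key=values.__getitem__)


# Assumption: one validation score per model, stored under the metric *name*.
val_performance = defaultdict(list)
val_performance["precision"].extend([0.71, 0.84, 0.78])

metric = FakeMetric("precision")

# Buggy lookup: the metric object itself is used as the key, so the
# defaultdict hands back a fresh empty list and argmax fails.
try:
    argmax(val_performance[metric])
except ValueError as err:
    print("bug reproduced:", err)

# Fixed lookup: normalize the metric to its name before indexing.
best_model_index = argmax(val_performance[metric_key(metric)])
print("best model index:", best_model_index)  # -> 1
```

With the key normalized, both metric="precision" and a metric object named "precision" resolve to the same entry in val_performance.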