find_best_architecture fails with tf.keras.metrics objects

Open: maltelueken opened this issue 3 years ago • 0 comments

Passing a tf.keras.metrics object as the metric argument of find_best_architecture(), e.g. find_best_architecture(..., metric=tf.keras.metrics.Precision()), raises an error:

models = modelgen.generate_models(X_train.shape, y_train.shape[1],
                                  number_of_models=number_of_models,
                                  task=task,
                                  metrics=[metric],
                                  **kwargs)
_, val_performance, _ = train_models_on_samples(X_train,
                                                y_train,
                                                X_val,
   (...)
                                                model_path=model_path,
                                                class_weight=class_weight)
--> best_model_index = np.argmax(val_performance[metric])

ValueError: attempt to get argmax of an empty sequence

The error occurs because find_best_architecture() stores each model's score on the validation set in the dictionary val_performance, keyed by the metric name as a string. When a tf.keras.metrics object is passed instead of a string, there is no entry under that key, so val_performance[metric] yields an empty list and np.argmax raises the ValueError above.
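A minimal, dependency-free sketch of the failure and one possible workaround: resolve the metric to its name string before indexing. It assumes val_performance is keyed by metric name strings and that a missing key yields an empty list (modelled here with collections.defaultdict), and that Keras metric objects expose their name through a `name` attribute. FakePrecision, metric_key and argmax are hypothetical helpers for illustration, not part of mcfly; argmax is a pure-Python stand-in for np.argmax.

```python
from collections import defaultdict
from typing import Any, List, Union


class FakePrecision:
    """Hypothetical stand-in for tf.keras.metrics.Precision().

    Only the `name` attribute matters here; real Keras metric objects
    also have one (its exact value can vary, e.g. with a numeric suffix).
    """
    name = "precision"


def metric_key(metric: Union[str, Any]) -> str:
    """Return the dictionary key for a metric given as a string or an object.

    Assumption: metric objects expose their name via a `name` attribute.
    Plain strings are returned unchanged.
    """
    if isinstance(metric, str):
        return metric
    return metric.name


def argmax(values: List[float]) -> int:
    """Pure-Python argmax; raises ValueError on an empty list, like np.argmax."""
    if not values:
        raise ValueError("attempt to get argmax of an empty sequence")
    # max() returns the first index holding the largest value.
    return max(range(len(values)), key=values.__getitem__)


# Hypothetical validation scores for three candidate models, keyed by
# metric name. defaultdict(list) yields [] for a missing key instead of
# raising KeyError, which reproduces the empty-sequence error.
val_performance = defaultdict(list)
val_performance["precision"] = [0.61, 0.74, 0.68]

metric = FakePrecision()

# Buggy lookup: the metric object itself is used as the key.
try:
    argmax(val_performance[metric])
except ValueError as err:
    print("buggy lookup:", err)

# Workaround: normalise the metric to its name before indexing.
best_model_index = argmax(val_performance[metric_key(metric)])
print("best model index:", best_model_index)
```

Normalising once at the top of find_best_architecture() would let the rest of the function keep using string keys; whether the stored keys match metric.name exactly (or carry a prefix such as val_) would need to be checked against how train_models_on_samples builds val_performance.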

maltelueken, Dec 20 '22 11:12