Built-in Pytorch Classifiers Crash TPOT
The two PyTorch-based TPOT classifiers currently support only binary targets, as I discovered while reviewing the GitHub repo. However, rather than failing safely, they crash TPOT when presented with a multi-class problem, e.g., MNIST digits.
Context of the issue
I have been using TPOT with non-DL models and have been impressed by its ability to construct complex pipelines efficiently with minimal guidance from the user. Much of my work involves deep learning, so I was happy to see that TPOT now supports PyTorch. After trying the TPOT NN example, which failed badly on my system, I created a configuration dictionary (attached) that included only the two PyTorch classifiers. These also caused a crash, but in doing so revealed the specific issue: no current support for multi-class problems.
Process to reproduce the issue
- Run the script Hale_Ex1_NN.py, making sure it can import from TPOT_NN_Cfg_T1. (Both scripts are attached.)
- Observe the failures shown below.
Note: Duplicate _pre_test decorator error messages have been removed.
Expected result
Ideally, the classifiers should support multi-class targets. Failing that, TPOT should not crash but should stop with an appropriate message, e.g., "multi-class support not yet available."
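Until multi-class support exists, one user-side way to get this behavior is to inspect the target before calling fit() and drop binary-only estimators from the configuration dictionary, warning instead of crashing. The sketch below is a minimal illustration under stated assumptions: it assumes the configuration dictionary uses TPOT's import-path keys (e.g. "tpot.builtins.PytorchMLPClassifier") and a 1-D target; the helper drop_binary_only_estimators and the BINARY_ONLY set are hypothetical and not part of TPOT.

```python
import warnings

# Hypothetical set of binary-only estimators, assumed to be the two TPOT
# PyTorch classifiers as they would appear as config_dict keys.
BINARY_ONLY = {
    "tpot.builtins.PytorchLRClassifier",
    "tpot.builtins.PytorchMLPClassifier",
}


def drop_binary_only_estimators(config_dict, y):
    """Return a config dict that is safe to use with the 1-D target y.

    Binary target: a copy of config_dict is returned unchanged.
    Multi-class target: binary-only estimators are dropped with a warning;
    if nothing would remain, a ValueError is raised instead.
    """
    n_classes = len(set(y))
    if n_classes <= 2:
        return dict(config_dict)

    removed = sorted(k for k in config_dict if k in BINARY_ONLY)
    kept = {k: v for k, v in config_dict.items() if k not in BINARY_ONLY}
    if not kept:
        # Nothing left to search: fail early with a descriptive message.
        raise ValueError(
            f"Target has {n_classes} classes, but every estimator in the "
            "configuration supports binary targets only. Multi-class support "
            "is not yet available."
        )
    if removed:
        warnings.warn(
            f"Target has {n_classes} classes; multi-class support is not yet "
            f"available for {', '.join(removed)}, so these were removed from "
            "the search space."
        )
    return kept
```

With a configuration like the one in this report (only the two PyTorch classifiers) and MNIST's ten digit classes, this helper would raise a clear ValueError before TPOTClassifier(config_dict=...).fit() is ever reached, rather than letting the search fail partway through.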
Current result
/Users/david/opt/anaconda3/envs/TPOT_Torch/bin/python /Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevconsole.py --mode=client --port=49961
import sys; print('Python %s on %s' % (sys.version, sys.platform))
sys.path.extend(['/Users/david/PycharmProjects/KJStraddle'])
Python 3.8.3 (default, Jul 2 2020, 11:26:31)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.19.0 -- An enhanced Interactive Python. Type '?' for help.
PyDev console: using IPython 7.19.0
Python 3.8.3 (default, Jul 2 2020, 11:26:31)
[Clang 10.0.0 ] on darwin
runfile('/Users/david/PycharmProjects/KJStraddle/Hale_Ex1_NN.py', wdir='/Users/david/PycharmProjects/KJStraddle')
Starting iteration 0
7 operators have been imported by TPOT.
_pre_test decorator: _random_mutation_operator: num_test=0 Non-binary targets not supported.
_pre_test decorator: _random_mutation_operator: num_test=1 Non-binary targets not supported.
_pre_test decorator: _random_mutation_operator: num_test=2 Non-binary targets not supported.
_pre_test decorator: _random_mutation_operator: num_test=3 Non-binary targets not supported.
_pre_test decorator: _random_mutation_operator: num_test=4 Non-binary targets not supported.
_pre_test decorator: _random_mutation_operator: num_test=5 Non-binary targets not supported.
_pre_test decorator: _random_mutation_operator: num_test=6 Non-binary targets not supported.
_pre_test decorator: _random_mutation_operator: num_test=7 Non-binary targets not supported.
_pre_test decorator: _random_mutation_operator: num_test=8 Non-binary targets not supported.
_pre_test decorator: _random_mutation_operator: num_test=9 Non-binary targets not supported.
_pre_test decorator: _mate_operator: num_test=0 Non-binary targets not supported.
_pre_test decorator: _mate_operator: num_test=1 Non-binary targets not supported.
_pre_test decorator: _mate_operator: num_test=2 Non-binary targets not supported.
_pre_test decorator: _mate_operator: num_test=3 Non-binary targets not supported.
_pre_test decorator: _mate_operator: num_test=4 Non-binary targets not supported.
_pre_test decorator: _mate_operator: num_test=5 Non-binary targets not supported.
_pre_test decorator: _mate_operator: num_test=6 Non-binary targets not supported.
_pre_test decorator: _mate_operator: num_test=7 Non-binary targets not supported.
_pre_test decorator: _mate_operator: num_test=8 Non-binary targets not supported.
_pre_test decorator: _mate_operator: num_test=9 Non-binary targets not supported.
Pipeline encountered that has previously been evaluated during the optimization process. Using the score from the previous evaluation.
Generation 1 - Current Pareto front scores:
-1 -inf PytorchMLPClassifier(input_matrix, PytorchMLPClassifier__batch_size=16, PytorchMLPClassifier__learning_rate=0.01, PytorchMLPClassifier__num_epochs=10, PytorchMLPClassifier__weight_decay=0.001)
/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/sklearn/model_selection/_split.py:670: UserWarning: The least populated class in y has only 1 members, which is less than n_splits=5.
warnings.warn(("The least populated class in y has only %d"
/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/sklearn/utils/validation.py:67: FutureWarning: Pass allow_nan=[8 0 6 4 2 4 0 3 3 3 9 3 6 8 4 4 1 5 1 6 7 6 6 1 0 9 2 6 4 0 5 5 9 4 2 6 5
9 4 8] as keyword args. From version 0.25 passing these as positional arguments will result in an error
warnings.warn("Pass {} as keyword args. From version 0.25 "
(The two warnings above were repeated nine more times; duplicates removed.)
Traceback (most recent call last):
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/tpot/base.py", line 730, in fit
self._pop, _ = eaMuPlusLambda(
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/tpot/gp_deap.py", line 281, in eaMuPlusLambda
per_generation_function(gen)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/tpot/base.py", line 1052, in _check_periodic_pipeline
self._update_top_pipeline()
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/tpot/base.py", line 830, in _update_top_pipeline
cv_scores = cross_val_score(sklearn_pipeline,
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/sklearn/utils/validation.py", line 72, in inner_f
return f(**kwargs)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/sklearn/model_selection/_validation.py", line 401, in cross_val_score
cv_results = cross_validate(estimator=estimator, X=X, y=y, groups=groups,
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/sklearn/utils/validation.py", line 72, in inner_f
return f(**kwargs)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/sklearn/model_selection/_validation.py", line 242, in cross_validate
scores = parallel(
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/joblib/parallel.py", line 1041, in __call__
if self.dispatch_one_batch(iterator):
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/joblib/parallel.py", line 859, in dispatch_one_batch
self._dispatch(tasks)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/joblib/parallel.py", line 777, in _dispatch
job = self._backend.apply_async(batch, callback=cb)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/joblib/_parallel_backends.py", line 208, in apply_async
result = ImmediateResult(func)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/joblib/_parallel_backends.py", line 572, in __init__
self.results = batch()
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/joblib/parallel.py", line 262, in __call__
return [func(*args, **kwargs)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/joblib/parallel.py", line 262, in <listcomp>
return [func(*args, **kwargs)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/sklearn/model_selection/_validation.py", line 531, in _fit_and_score
estimator.fit(X_train, y_train, **fit_params)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/sklearn/pipeline.py", line 335, in fit
self._final_estimator.fit(Xt, y, **fit_params_last_step)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/tpot/builtins/nn.py", line 122, in fit
self._init_model(X, y)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/tpot/builtins/nn.py", line 315, in _init_model
X, y = self.validate_inputs(X, y)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/tpot/builtins/nn.py", line 166, in validate_inputs
raise ValueError("Non-binary targets not supported")
ValueError: Non-binary targets not supported
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3418, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-2-3a4542fd0455>", line 1, in <module>
runfile('/Users/david/PycharmProjects/KJStraddle/Hale_Ex1_NN.py', wdir='/Users/david/PycharmProjects/KJStraddle')
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/Users/david/PycharmProjects/KJStraddle/Hale_Ex1_NN.py", line 34, in <module>
tpot.fit(X_train, y_train)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/tpot/base.py", line 773, in fit
raise e
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/tpot/base.py", line 764, in fit
self._update_top_pipeline()
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/tpot/base.py", line 830, in _update_top_pipeline
cv_scores = cross_val_score(sklearn_pipeline,
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/sklearn/utils/validation.py", line 72, in inner_f
return f(**kwargs)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/sklearn/model_selection/_validation.py", line 401, in cross_val_score
cv_results = cross_validate(estimator=estimator, X=X, y=y, groups=groups,
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/sklearn/utils/validation.py", line 72, in inner_f
return f(**kwargs)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/sklearn/model_selection/_validation.py", line 242, in cross_validate
scores = parallel(
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/joblib/parallel.py", line 1041, in __call__
if self.dispatch_one_batch(iterator):
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/joblib/parallel.py", line 859, in dispatch_one_batch
self._dispatch(tasks)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/joblib/parallel.py", line 777, in _dispatch
job = self._backend.apply_async(batch, callback=cb)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/joblib/_parallel_backends.py", line 208, in apply_async
result = ImmediateResult(func)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/joblib/_parallel_backends.py", line 572, in __init__
self.results = batch()
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/joblib/parallel.py", line 262, in __call__
return [func(*args, **kwargs)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/joblib/parallel.py", line 262, in <listcomp>
return [func(*args, **kwargs)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/sklearn/model_selection/_validation.py", line 531, in _fit_and_score
estimator.fit(X_train, y_train, **fit_params)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/sklearn/pipeline.py", line 335, in fit
self._final_estimator.fit(Xt, y, **fit_params_last_step)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/tpot/builtins/nn.py", line 122, in fit
self._init_model(X, y)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/tpot/builtins/nn.py", line 315, in _init_model
X, y = self.validate_inputs(X, y)
File "/Users/david/opt/anaconda3/envs/TPOT_Torch/lib/python3.8/site-packages/tpot/builtins/nn.py", line 166, in validate_inputs
raise ValueError("Non-binary targets not supported")
ValueError: Non-binary targets not supported
Possible fix
I volunteer to (1) add code so that TPOT fails safely and (2) extend the classifiers to support multi-class targets (or assist as needed).
Which versions of scikit-learn and TPOT are in your environment? I suspect that the latest version of scikit-learn (0.24) has an incompatibility issue. See #1157.
Good to hear from you.
scikit-learn 0.23.2, TPOT 0.11.6.post3
I did not notice any incompatibilities and assumed the cause lay in the PyTorch classifier code, given its current limitation to binary targets. Should I upgrade scikit-learn to 0.24 and try again?
Thanks.
P.S. I am very impressed with TPOT's ease of use and utility. I hope to see more models and deep learning options added over time, but even with these limitations, TPOT is already my preferred AutoML tool.
P.P.S. Would you prefer that I respond on GitHub?
David L. Wilt
3272 Bayou Road
Longboat Key FL 34228
540-420-0844
From: Weixuan Fu [email protected] Reply-To: EpistasisLab/tpot [email protected] Date: Tuesday, January 5, 2021 at 3:31 PM To: EpistasisLab/tpot [email protected] Cc: DLWCMD [email protected], Author [email protected] Subject: Re: [EpistasisLab/tpot] Built-in Pytorch Classifiers Crash TPOT (#1149)
@JDRomano2 do you have any idea about this issue?
We are actively working on adding multi-class support (as well as support for regression) with PyTorch estimators, but in the meantime it is important we fix TPOT to fail gracefully.
@DLWCMD - Please read the contributing guide (https://epistasislab.github.io/tpot/contributing/) if you would like to add this safety check until we implement support for multi-class targets. We can also patch this check in ourselves; we should be able to do so within the next few days.
Basically, TPOT should raise an exception when TPOTClassifier.fit() is called if both of the following are true:
- The TPOTClassifier object was initialized with PyTorch estimators enabled.
- The target passed to fit() is multi-class (i.e., non-binary).
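A minimal sketch of the check described above, in plain Python. The names uses_pytorch_estimators and check_pytorch_target are hypothetical, and detecting PyTorch estimators via the "TPOT NN" preset name or via "Pytorch" in the configuration keys is an assumption about how the configuration might be inspected, not TPOT's actual implementation. Calling such a check at the top of fit(), before the evolutionary search begins, would replace the late crash in _update_top_pipeline seen in the traceback with an immediate, descriptive error.

```python
from collections.abc import Mapping


def uses_pytorch_estimators(config):
    """Heuristic: True if the TPOT configuration enables PyTorch estimators.

    Assumes either the "TPOT NN" preset string or a dict whose keys are
    import paths such as "tpot.builtins.PytorchMLPClassifier".
    """
    if isinstance(config, str):
        return config == "TPOT NN"
    if isinstance(config, Mapping):
        return any("Pytorch" in key for key in config)
    return False


def check_pytorch_target(config, target):
    """Raise a clear ValueError up front instead of crashing mid-search."""
    n_classes = len(set(target))
    if uses_pytorch_estimators(config) and n_classes > 2:
        raise ValueError(
            "The PyTorch estimators currently support binary targets only, "
            f"but the target has {n_classes} classes. Multi-class support is "
            "not yet available; remove the PyTorch estimators from the "
            "configuration or use a binary target."
        )
```

Binary targets and configurations without PyTorch estimators pass through silently, so the check only affects the failing combination described in this issue.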
Joe,
Good to hear from you, and thanks for the explanation.
I will read the contributing guide; I am confident I can develop code for the situations you outlined. I will be in touch.
Assuming I pass this test (☺), I would be interested in supporting your new PyTorch development.
Thanks for your good work with TPOT: I think it is breaking new ground in the AutoML space.
David L. Wilt
From: Joe Romano [email protected] Reply-To: EpistasisLab/tpot [email protected] Date: Tuesday, January 5, 2021 at 5:45 PM To: EpistasisLab/tpot [email protected] Cc: DLWCMD [email protected], Mention [email protected] Subject: Re: [EpistasisLab/tpot] Built-in Pytorch Classifiers Crash TPOT (#1149)