
Add classes_ to PytorchLRClassifier

Open · perib opened this pull request · 0 comments


What does this PR do?

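Adds a classes_ attribute to the PyTorch classifiers, so that, like other fitted scikit-learn classifiers, they expose the labels they were trained on.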
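For context, scikit-learn's convention is that a fitted classifier stores the sorted unique training labels in classes_ and returns predictions in that label space. The sketch below illustrates that pattern only. It is a hypothetical, torch-free stand-in, not the code in this PR: the class name ClassesSketchClassifier and its nearest-centroid "model" are assumptions chosen to keep the example small and runnable.

```python
import numpy as np


class ClassesSketchClassifier:
    """Hypothetical stand-in for a PyTorch-backed classifier (no torch used).

    Shows the scikit-learn convention: fit() records the sorted unique labels
    in ``classes_`` and the model itself only works with integer indices.
    """

    def fit(self, X, y):
        X = np.asarray(X, dtype=float)
        # np.unique returns the sorted distinct labels and, via return_inverse,
        # each sample's label re-encoded as an index into classes_.
        self.classes_, y_encoded = np.unique(y, return_inverse=True)
        # Placeholder "training" (assumption): per-class feature means, standing
        # in for the logistic-regression network used by the real classifier.
        self.centroids_ = np.stack(
            [X[y_encoded == k].mean(axis=0) for k in range(len(self.classes_))]
        )
        return self

    def predict(self, X):
        X = np.asarray(X, dtype=float)
        # Distance from every sample to every class centroid: (n_samples, n_classes).
        dists = np.linalg.norm(X[:, None, :] - self.centroids_[None, :, :], axis=2)
        # Map the winning index back to the original label through classes_.
        return self.classes_[dists.argmin(axis=1)]


clf = ClassesSketchClassifier().fit([[0, 0], [0, 1], [5, 5], [5, 6]], ["b", "b", "a", "a"])
print(clf.classes_)                          # ['a' 'b']
print(clf.predict([[0, 0.5], [5, 5.5]]))     # ['b' 'a']
```

Because the model only ever sees integer indices, classes_ is what maps them back, which is why non-integer labels such as strings round-trip correctly.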

Where should the reviewer start?

How should this PR be tested?

from tpot import TPOTClassifier
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split

# Small two-class toy dataset with three features
X, y = make_blobs(n_samples=100, centers=2, n_features=3, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.75, test_size=0.25)

# Use the NN config and force pipelines that end in PytorchLRClassifier
clf = TPOTClassifier(config_dict='TPOT NN', template='Selector-Transformer-PytorchLRClassifier',
                     verbosity=2, population_size=2, generations=2)
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))

# Write the best pipeline found to a standalone script
clf.export('tpot_nn_demo_pipeline.py')

What are the relevant issues?

#1339

perib, Apr 22 '24 20:04