bricks
bricks copied to clipboard
[MODULE] - Grid search active learner
Please describe the module you would like to add to the content library: Sklearn-based grid search to train an active learner classification head.
Implementation
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
class ActiveLearner:
    """Active-learning wrapper around a scikit-learn style classifier.

    The wrapper stores its constructor arguments unmodified and exposes
    ``get_params`` / ``set_params`` / ``fit`` / ``predict`` so that it can be
    cloned and tuned by tools such as ``GridSearchCV``. Parameters of the
    wrapped classifier are reachable through the ``base_classifier__<name>``
    prefix.

    NOTE(review): recent scikit-learn releases prefer estimators that inherit
    from ``sklearn.base.BaseEstimator``; consider switching to that base class
    once the module is integrated.

    Args:
        base_classifier: Classifier offering ``fit``, ``predict`` and
            ``predict_proba``. If ``None``, a ``RandomForestClassifier`` is
            created lazily on the first call to ``fit``.
        pool_size: Size of the unlabeled pool (stored only; not used yet).
        query_strategy: Name of the sampling strategy, e.g. ``"uncertainty"``
            or ``"diversity"``.
    """

    # NOTE(review): scikit-learn helpers have historically read this attribute
    # to detect classifiers -- confirm against the installed version.
    _estimator_type = "classifier"

    # Constructor arguments, in signature order; drives get_params/set_params.
    _PARAM_NAMES = ("base_classifier", "pool_size", "query_strategy")

    def __init__(self, base_classifier, pool_size, query_strategy):
        # Store the arguments as given: cloning relies on constructor
        # parameters round-tripping through get_params() unchanged.
        self.base_classifier = base_classifier
        self.pool_size = pool_size
        self.query_strategy = query_strategy

    def get_params(self, deep=True):
        """Return the constructor parameters as a dict.

        With ``deep=True`` the wrapped classifier's own parameters are
        included under ``base_classifier__<name>`` keys, provided the
        classifier implements ``get_params``.
        """
        params = {name: getattr(self, name) for name in self._PARAM_NAMES}
        if deep and hasattr(self.base_classifier, "get_params"):
            nested = self.base_classifier.get_params(deep=True)
            for key, value in nested.items():
                params["base_classifier__" + key] = value
        return params

    def set_params(self, **params):
        """Set parameters on the wrapper and/or the wrapped classifier.

        Raises:
            ValueError: If a key does not name a constructor parameter or a
                ``base_classifier__<name>`` nested parameter.
        """
        nested = {}
        for key, value in params.items():
            name, delimiter, sub_key = key.partition("__")
            if name not in self._PARAM_NAMES:
                raise ValueError(
                    "Invalid parameter %r for ActiveLearner" % (key,)
                )
            if not delimiter:
                setattr(self, name, value)
            elif name == "base_classifier":
                nested[sub_key] = value
            else:
                raise ValueError(
                    "Invalid parameter %r for ActiveLearner" % (key,)
                )
        # Applied last so nested values reach a classifier that was replaced
        # in the same call.
        if nested:
            self.base_classifier.set_params(**nested)
        return self

    def select_next_sample(self, X_pool, Y_pool):
        """Placeholder for choosing the next sample to label.

        No query strategy is implemented yet, so this always returns ``None``.
        """
        # TODO: implement the "uncertainty" and "diversity" strategies.
        if self.query_strategy == "uncertainty":
            pass
        elif self.query_strategy == "diversity":
            pass
        return None

    def fit(self, X_train, y_train, X_test=None, y_test=None):
        """Fit the wrapped classifier on the training data.

        Args:
            X_train: Training features.
            y_train: Training labels.
            X_test: Optional features to predict on right after fitting.
            y_test: Accepted for backward compatibility; unused.

        Returns:
            ``self`` when ``X_test`` is omitted (the form ``GridSearchCV``
            uses); otherwise the class probabilities predicted for ``X_test``.
        """
        if self.base_classifier is None:
            self.base_classifier = RandomForestClassifier()
        self.base_classifier.fit(X_train, y_train)
        if X_test is None:
            return self
        return self.base_classifier.predict_proba(X_test)

    def predict(self, X):
        """Return the wrapped classifier's label predictions for ``X``."""
        return self.base_classifier.predict(X)

    def predict_proba(self, X):
        """Return the wrapped classifier's class probabilities for ``X``."""
        return self.base_classifier.predict_proba(X)
def gridSearch(X_train=None, y_train=None, X_test=None):
    """Grid-search an ``ActiveLearner`` and predict labels for ``X_test``.

    Searches over the number of trees of the wrapped random forest and the
    learner's pool size, scoring candidates by accuracy.

    Args:
        X_train: Training features used for the cross-validated search.
        y_train: Training labels.
        X_test: Features to predict with the best estimator found.

    Returns:
        Predictions of the refitted best estimator for ``X_test``.

    Raises:
        ValueError: If any of the three data arguments is missing.
    """
    if X_train is None or y_train is None or X_test is None:
        raise ValueError(
            "gridSearch requires X_train, y_train and X_test to be provided"
        )

    # "n_estimators" belongs to the wrapped random forest, not to the learner,
    # so it is addressed with scikit-learn's "<component>__<param>" syntax.
    # This requires ActiveLearner to expose get_params/set_params.
    param_grid = {
        "base_classifier__n_estimators": [10, 50, 100],
        "pool_size": [50, 100, 500],
    }

    rf_classifier = RandomForestClassifier()
    active_learner = ActiveLearner(
        base_classifier=rf_classifier,
        pool_size=500,
        query_strategy="uncertainty",
    )
    grid_search = GridSearchCV(
        estimator=active_learner,
        param_grid=param_grid,
        scoring="accuracy",
    )

    grid_search.fit(X_train, y_train)
    return grid_search.predict(X_test)
Additional context -
Currently not usable in bricks, as the integrator input is still to be implemented.