Error when calling the `interpret` function of `SingleTreeCateInterpreter`
I am using DML for causal inference. Fitting the DML model succeeds, but when I interpret the model, the following error occurs: InvalidParameterError: The 'criterion' parameter of DecisionTreeRegressor must be a str among {'friedman_mse', 'absolute_error', 'poisson', 'squared_error'}. Got 'mse' instead.
The code is as follows:
# Restrict to the two sessions being compared. `.copy()` makes this an
# independent frame, so adding the 'Treatment' column below does not write
# into a filtered slice of `data` (pandas SettingWithCopyWarning).
analytic_data = data[data['session_num'].isin([exp_session, ctr_session])].copy()

# Outcome and covariates. X is built before 'Treatment' is added, so the
# treatment indicator is never used as a feature.
y = analytic_data['ret_1'].astype("float").values
X = analytic_data.drop(['t0.member_id', 'session_num', 'ret_1'], axis=1)

# Binary treatment: 0 for the control session, 1 otherwise. Vectorized
# equivalent of a row-wise `apply(..., axis=1)`; it is faster and it also
# yields a proper Series when the filtered frame is empty.
analytic_data['Treatment'] = (analytic_data['session_num'] != ctr_session).astype(int)
T = analytic_data['Treatment'].astype("float").values

X_train, X_test, y_train, y_test, T_train, T_test = train_test_split(
    X, y, T, test_size=0.5, random_state=101
)

# NOTE(review): T is binary; consider `discrete_treatment=True` with a
# classifier for `model_t`. Left as-is here to preserve the original fit.
est = CausalForestDML(
    model_y=RandomForestRegressor(n_jobs=-1, max_depth=10, n_estimators=20),
    model_t=RandomForestRegressor(n_jobs=-1, max_depth=10, n_estimators=20),
    n_estimators=100,
    max_depth=8,
    min_samples_leaf=20000,
)
est.fit(Y=y_train, T=T_train, X=X_train, W=None)

intrp = SingleTreeCateInterpreter(max_depth=3, random_state=30)
# Workaround for the InvalidParameterError in the traceback: `interpret`
# builds its surrogate DecisionTreeRegressor with `criterion=self.criterion`,
# and this econml install sets that to the legacy name 'mse', which the
# installed scikit-learn rejects. 'squared_error' is the accepted equivalent.
# The proper long-term fix is upgrading econml to a release that supports
# this scikit-learn version (presumably econml >= 0.14 -- verify).
if getattr(intrp, "criterion", None) == "mse":
    intrp.criterion = "squared_error"
intrp.interpret(est, X_test)

plt.figure(figsize=(15, 8))
intrp.plot(feature_names=X_test.columns, fontsize=12)
The full error traceback is as follows:
InvalidParameterError Traceback (most recent call last) Cell In[15], line 2 1 intrp = SingleTreeCateInterpreter(max_depth=3, random_state=30) ----> 2 intrp.interpret(est, X_test) 3 plt.figure(figsize=(15, 8)) 4 intrp.plot(feature_names=X_test.columns, fontsize=12)
File ~/anaconda3/lib/python3.10/site-packages/econml/cate_interpreter/interpreters.py:193, in SingleTreeCateInterpreter.interpret(self, cate_estimator, X) 181 self.tree_model = DecisionTreeRegressor(criterion=self.criterion, 182 splitter=self.splitter, 183 max_depth=self.max_depth, (...) 189 max_leaf_nodes=self.max_leaf_nodes, 190 min_impurity_decrease=self.min_impurity_decrease) 191 y_pred = cate_estimator.const_marginal_effect(X) --> 193 self.tree_model_.fit(X, y_pred.reshape((y_pred.shape[0], -1))) 194 paths = self.tree_model_.decision_path(X) 195 node_dict = {}
File ~/anaconda3/lib/python3.10/site-packages/sklearn/tree/_classes.py:1247, in DecisionTreeRegressor.fit(self, X, y, sample_weight, check_input) 1218 def fit(self, X, y, sample_weight=None, check_input=True): 1219 """Build a decision tree regressor from the training set (X, y). 1220 1221 Parameters (...) 1244 Fitted estimator. 1245 """ -> 1247 super().fit( 1248 X, 1249 y, 1250 sample_weight=sample_weight, 1251 check_input=check_input, 1252 ) 1253 return self
File ~/anaconda3/lib/python3.10/site-packages/sklearn/tree/_classes.py:177, in BaseDecisionTree.fit(self, X, y, sample_weight, check_input) 176 def fit(self, X, y, sample_weight=None, check_input=True): --> 177 self._validate_params() 178 random_state = check_random_state(self.random_state) 180 if check_input: 181 # Need to validate separately here. 182 # We can't pass multi_output=True because that would allow y to be 183 # csr.
File ~/anaconda3/lib/python3.10/site-packages/sklearn/base.py:581, in BaseEstimator._validate_params(self)
573 def _validate_params(self):
574 """Validate types and values of constructor parameters
575
576 The expected type and values must be defined in the _parameter_constraints
(...)
579 accepted constraints.
580 """
--> 581 validate_parameter_constraints(
582 self._parameter_constraints,
583 self.get_params(deep=False),
584 caller_name=self.__class__.__name__,
585 )
File ~/anaconda3/lib/python3.10/site-packages/sklearn/utils/_param_validation.py:97, in validate_parameter_constraints(parameter_constraints, params, caller_name) 91 else: 92 constraints_str = ( 93 f"{', '.join([str(c) for c in constraints[:-1]])} or" 94 f" {constraints[-1]}" 95 ) ---> 97 raise InvalidParameterError( 98 f"The {param_name!r} parameter of {caller_name} must be" 99 f" {constraints_str}. Got {param_val!r} instead." 100 )
InvalidParameterError: The 'criterion' parameter of DecisionTreeRegressor must be a str among {'friedman_mse', 'absolute_error', 'poisson', 'squared_error'}. Got 'mse' instead.