MemoryError: Unable to allocate sufficient memory
Issue Description
I used a Random Survival Forest with 10 estimators and a max depth of 25 on approximately 1800 data samples. The full dataset contains approximately 200,000 samples, but I intentionally used only a very small subsample in the run that produced this error.
When attempting to fit a ModelSurvSHAP on this very small dummy random survival forest, I get the following error: MemoryError: Unable to allocate 512. TiB for an array with shape (8388608, 8388608) and data type float64
I'm using survshap version 0.4.2.
Minimal Reproducible Code Sample
from sksurv.ensemble import RandomSurvivalForest
from survshap import SurvivalModelExplainer, ModelSurvSHAP
# random_state and X_train, y_train, X_test, y_test are defined earlier in the notebook (not shown)
rsf = RandomSurvivalForest(
    n_estimators=10, max_depth=25, min_samples_split=10, min_samples_leaf=15, n_jobs=-1, random_state=random_state
)
rsf.fit(X_train, y_train)
rsf_exp = SurvivalModelExplainer(rsf, X_test, y_test)
exp1_survshap_global_rsf = ModelSurvSHAP(random_state=42)
exp1_survshap_global_rsf.fit(rsf_exp)
Error Trace:
---------------------------------------------------------------------------
MemoryError Traceback (most recent call last)
Cell In[38], line 6
3 rsf_exp = SurvivalModelExplainer(rsf, X_test, y_test)
5 exp1_survshap_global_rsf = ModelSurvSHAP(random_state=42)
----> 6 exp1_survshap_global_rsf.fit(rsf_exp)
File c:\Users\alenk\anaconda3\envs\azureml_py310_sdkv2\lib\site-packages\survshap\model_explanations\object.py:76, in ModelSurvSHAP.fit(self, explainer, new_observations, timestamps, save_individual_explanations, **kwargs)
69 if new_observations is None:
70 new_observations = explainer.data
72 (
73 self.full_result,
74 self.individual_explanations,
75 self.timestamps,
---> 76 ) = calculate_individual_explanations(
77 explainer,
78 new_observations,
79 self.function_type,
80 self.path,
81 self.B,
82 self.max_shap_value_inputs,
83 self.random_state,
84 self.calculation_method,
85 self.aggregation_method,
86 timestamps,
87 save_individual_explanations,
88 **kwargs
89 )
91 names = explainer.y.dtype.names
92 self.event_ind = explainer.y[names[0]]
File c:\Users\alenk\anaconda3\envs\azureml_py310_sdkv2\lib\site-packages\survshap\model_explanations\utils.py:127, in calculate_individual_explanations(explainer, new_observations, function_type, path, B, max_shap_value_inputs, random_state, calculation_method, aggregation_method, timestamps, save_individual_explanations, **kwargs)
117 for i in tqdm(range(len(new_observations))):
118 survSHAP_obj = PredictSurvSHAP(
119 function_type=function_type,
120 path=path,
(...)
125 random_state=random_state,
126 )
--> 127 survSHAP_obj.fit(explainer, new_observations.iloc[[i]], timestamps)
128 if save_individual_explanations:
129 individual_explanations.append(survSHAP_obj)
File c:\Users\alenk\anaconda3\envs\azureml_py310_sdkv2\lib\site-packages\survshap\predict_explanations\object.py:81, in PredictSurvSHAP.fit(self, explainer, new_observation, timestamps, y_true)
72 self.y_true_time = y_true[names[1]]
74 if self.calculation_method == "kernel":
75 (
76 self.result,
77 self.predicted_function,
78 self.baseline_function,
79 self.timestamps,
80 self.r2,
---> 81 ) = shap_kernel(
82 explainer,
83 new_observation,
84 self.function,
85 self.aggregation_method,
86 timestamps,
87 self.max_shap_value_inputs,
88 )
89 elif self.calculation_method == "sampling":
90 (
91 self.result,
92 self.predicted_function,
(...)
104 self.exact,
105 )
File c:\Users\alenk\anaconda3\envs\azureml_py310_sdkv2\lib\site-packages\survshap\predict_explanations\utils.py:106, in shap_kernel(explainer, new_observation, function_type, aggregation_method, timestamps, max_shap_value_inputs)
101 print(
102 f"Approximate Survival Shapley will sample only {max_shap_value_inputs} values instead of 2**{p} for Exact Shapley"
103 )
105 kernel_weights = generate_shap_kernel_weights(simplified_inputs, p)
--> 106 shap_values, r2 = calculate_shap_values(
107 explainer,
108 function_type,
109 baseline_f,
110 explainer.data,
111 simplified_inputs,
112 kernel_weights,
113 new_observation,
114 timestamps,
115 )
117 variable_names = explainer.data.columns
118 result = prepare_result_df(new_observation, variable_names, shap_values, timestamps, aggregation_method)
File c:\Users\alenk\anaconda3\envs\azureml_py310_sdkv2\lib\site-packages\survshap\predict_explanations\utils.py:158, in calculate_shap_values(model, function_type, avg_function, data, simplified_inputs, shap_kernel_weights, new_observation, timestamps)
148 def calculate_shap_values(
149 model,
150 function_type,
(...)
156 timestamps,
157 ):
--> 158 W = np.diag(shap_kernel_weights)
159 X = np.array(simplified_inputs)
160 R = np.linalg.inv(X.T @ W @ X) @ (X.T @ W)
File c:\Users\alenk\anaconda3\envs\azureml_py310_sdkv2\lib\site-packages\numpy\lib\twodim_base.py:293, in diag(v, k)
291 if len(s) == 1:
292 n = s[0]+abs(k)
--> 293 res = zeros((n, n), v.dtype)
294 if k >= 0:
295 i = k
MemoryError: Unable to allocate 512. TiB for an array with shape (8388608, 8388608) and data type float64
According to the error message, it looks like you are trying to explain 8 million variables. The matrix that raises the memory error requires O(N^2) memory, where N is the number of input variables, so this is not feasible. You can use the max_shap_value_inputs attribute to sample inputs and approximate the Shapley values, but max_shap_value_inputs should be greater than N, and with 8 million features it would still take a very long time. In conclusion, I suspect you need to reformulate the problem by reducing the number of input variables before using this tool.
You can try the code below with different combinations of nb_features, nb_events, and max_shap_value_inputs to see which settings suit your problem and your computer.
import numpy as np
import pandas as pd
from sksurv.ensemble import RandomSurvivalForest
from survshap import SurvivalModelExplainer, ModelSurvSHAP
np.random.seed(42)
def gen_xy(nb_features=10, nb_events=10, x_type=np.float32, date_type=np.float32):
    """Generate a random dummy survival dataset: features X and a structured target y."""
    X_train = np.random.rand(nb_events, nb_features).astype(x_type)
    np_time = np.random.rand(nb_events)
    # small noise in [0, 0.01]; np.clip takes (array, min, max)
    noise = np.clip(np.random.rand(nb_events) * 0.01, 0, 1)
    np_is_living = X_train[:, 0] < np_time + noise  # dummy event indicator
    y_train = np.empty(nb_events, dtype=[('event', bool), ('time', date_type)])
    y_train['event'] = np_is_living
    y_train['time'] = np_time
    X_train = pd.DataFrame(X_train, columns=['f' + str(i) for i in range(1, nb_features + 1)])
    return X_train, y_train
X_train, y_train = gen_xy(nb_events=1800)
X_test, y_test = gen_xy(nb_events=1800)
rsf = RandomSurvivalForest(
    n_estimators=10, max_depth=25, min_samples_split=10, min_samples_leaf=15, n_jobs=-1, random_state=42
)
rsf.fit(X_train, y_train)
rsf_exp = SurvivalModelExplainer(rsf, X_test, y_test)
exp1_survshap_global_rsf = ModelSurvSHAP(random_state=42, max_shap_value_inputs=20)
exp1_survshap_global_rsf.fit(rsf_exp)
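The O(N^2) memory cost mentioned in the reply above comes from the `W = np.diag(shap_kernel_weights)` line in the first traceback, which builds a dense square matrix whose side equals the number of coalition rows, only to scale those rows. As a minimal sketch of the general technique (plain NumPy, not survshap's code; the sizes `m`, `p` and the random weights are made up for illustration), the same weighted least-squares projection `inv(X.T @ W @ X) @ (X.T @ W)` can be computed by broadcasting the weight vector, which needs only O(m·p) memory:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy problem (hypothetical sizes): m sampled coalitions over p features
m, p = 200, 6
X = rng.integers(0, 2, size=(m, p)).astype(float)  # binary coalition matrix
w = rng.random(m) + 0.1                             # strictly positive kernel weights

# Dense form, as in the traceback: np.diag allocates an m x m matrix
W = np.diag(w)
R_dense = np.linalg.inv(X.T @ W @ X) @ (X.T @ W)

# Broadcast form: X.T * w scales column j of X.T by w[j], i.e. X.T @ diag(w)
XtW = X.T * w                            # shape (p, m), no m x m allocation
R_light = np.linalg.solve(XtW @ X, XtW)  # solves (X.T W X) R = X.T W

print(R_dense.shape, np.allclose(R_dense, R_light))  # (6, 200) True
```

Both forms give the same projection; the broadcast form only removes the quadratic allocation, while the number of model evaluations presumably still grows with the number of coalition rows.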
I have no idea where the 8 million variables would be coming from. The dataset I was testing this on had only 23 input variables, which is fairly modest.
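The shape in the first traceback is consistent with 23 features rather than 8 million: the kernel method's own message refers to `2**{p}` coalitions for exact Shapley values, and 2^23 = 8,388,608. A quick check of the arithmetic, assuming all coalitions were enumerated because `max_shap_value_inputs` was not set in the original snippet:

```python
# Number of input variables reported for the dataset
p = 23

# Exact kernel SHAP evaluates every coalition of the p features: 2**p of them
n_coalitions = 2 ** p
print(n_coalitions)  # 8388608, the matrix side length in the MemoryError

# np.diag(shap_kernel_weights) then allocates a dense n x n float64 matrix
bytes_needed = n_coalitions ** 2 * 8  # 8 bytes per float64 element
print(bytes_needed / 2 ** 40, "TiB")  # 512.0 TiB
```

If this reading is right, the failing allocation depends on the number of features, not on the number of samples, so subsampling rows does not shrink it; setting `max_shap_value_inputs` caps the number of coalition rows instead.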
I tried running the sample code you shared with the max_shap_value_inputs parameter. As provided, with max_shap_value_inputs=20, it also fails to complete: it hangs at about 17%, with no error and no progress after more than 2 hours. I was not running any other significant processes on my machine at the time.
I then tried an even smaller number of input variables (5) and a training dataset of only 100 records. This time I get a LinAlgError: Singular matrix.
---------------------------------------------------------------------------
LinAlgError Traceback (most recent call last)
Cell In[4], line 12
9 rsf_exp = SurvivalModelExplainer(rsf, X_test, y_test)
11 exp1_survshap_global_rsf = ModelSurvSHAP(random_state=42, max_shap_value_inputs=5)
---> 12 exp1_survshap_global_rsf.fit(rsf_exp)
File c:\Users\alenk\anaconda3\envs\py311_survival\Lib\site-packages\survshap\model_explanations\object.py:76, in ModelSurvSHAP.fit(self, explainer, new_observations, timestamps, save_individual_explanations, **kwargs)
69 if new_observations is None:
70 new_observations = explainer.data
72 (
73 self.full_result,
74 self.individual_explanations,
75 self.timestamps,
---> 76 ) = calculate_individual_explanations(
77 explainer,
78 new_observations,
79 self.function_type,
80 self.path,
81 self.B,
82 self.max_shap_value_inputs,
83 self.random_state,
84 self.calculation_method,
85 self.aggregation_method,
86 timestamps,
87 save_individual_explanations,
88 **kwargs
89 )
91 names = explainer.y.dtype.names
92 self.event_ind = explainer.y[names[0]]
File c:\Users\alenk\anaconda3\envs\py311_survival\Lib\site-packages\survshap\model_explanations\utils.py:127, in calculate_individual_explanations(explainer, new_observations, function_type, path, B, max_shap_value_inputs, random_state, calculation_method, aggregation_method, timestamps, save_individual_explanations, **kwargs)
117 for i in tqdm(range(len(new_observations))):
118 survSHAP_obj = PredictSurvSHAP(
119 function_type=function_type,
120 path=path,
(...)
125 random_state=random_state,
126 )
--> 127 survSHAP_obj.fit(explainer, new_observations.iloc[[i]], timestamps)
128 if save_individual_explanations:
129 individual_explanations.append(survSHAP_obj)
File c:\Users\alenk\anaconda3\envs\py311_survival\Lib\site-packages\survshap\predict_explanations\object.py:81, in PredictSurvSHAP.fit(self, explainer, new_observation, timestamps, y_true)
72 self.y_true_time = y_true[names[1]]
74 if self.calculation_method == "kernel":
75 (
76 self.result,
77 self.predicted_function,
78 self.baseline_function,
79 self.timestamps,
80 self.r2,
---> 81 ) = shap_kernel(
82 explainer,
83 new_observation,
84 self.function,
85 self.aggregation_method,
86 timestamps,
87 self.max_shap_value_inputs,
88 )
89 elif self.calculation_method == "sampling":
90 (
91 self.result,
92 self.predicted_function,
(...)
104 self.exact,
105 )
File c:\Users\alenk\anaconda3\envs\py311_survival\Lib\site-packages\survshap\predict_explanations\utils.py:106, in shap_kernel(explainer, new_observation, function_type, aggregation_method, timestamps, max_shap_value_inputs)
101 print(
102 f"Approximate Survival Shapley will sample only {max_shap_value_inputs} values instead of 2**{p} for Exact Shapley"
103 )
105 kernel_weights = generate_shap_kernel_weights(simplified_inputs, p)
--> 106 shap_values, r2 = calculate_shap_values(
107 explainer,
108 function_type,
109 baseline_f,
110 explainer.data,
111 simplified_inputs,
112 kernel_weights,
113 new_observation,
114 timestamps,
115 )
117 variable_names = explainer.data.columns
118 result = prepare_result_df(new_observation, variable_names, shap_values, timestamps, aggregation_method)
File c:\Users\alenk\anaconda3\envs\py311_survival\Lib\site-packages\survshap\predict_explanations\utils.py:160, in calculate_shap_values(model, function_type, avg_function, data, simplified_inputs, shap_kernel_weights, new_observation, timestamps)
158 W = np.diag(shap_kernel_weights)
159 X = np.array(simplified_inputs)
--> 160 R = np.linalg.inv(X.T @ W @ X) @ (X.T @ W)
161 y = (
162 make_prediction_for_simplified_input(model, function_type, data, simplified_inputs, new_observation, timestamps)
163 - avg_function
164 )
165 shap_values = R @ y
File c:\Users\alenk\anaconda3\envs\py311_survival\Lib\site-packages\numpy\linalg\linalg.py:561, in inv(a)
559 signature = 'D->D' if isComplexType(t) else 'd->d'
560 extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
--> 561 ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj)
562 return wrap(ainv.astype(result_t, copy=False))
File c:\Users\alenk\anaconda3\envs\py311_survival\Lib\site-packages\numpy\linalg\linalg.py:112, in _raise_linalgerror_singular(err, flag)
111 def _raise_linalgerror_singular(err, flag):
--> 112 raise LinAlgError("Singular matrix")
LinAlgError: Singular matrix
Could you provide a minimal snippet of code that represents your data/code and raises the error?
If the computed matrix is not invertible, you could add a small amount of random noise to your data to avoid linear dependencies between columns.
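Judging from the second traceback, the matrix being inverted is `X.T @ W @ X` with `X = np.array(simplified_inputs)`, so it is built from the binary coalition matrix and the kernel weights; the model predictions only enter later through `y`. Its rank is at most the rank of the sampled coalition matrix, so with 5 features and `max_shap_value_inputs=5`, one unlucky sample (for example a feature that is never switched on) is enough to make it singular. The plain NumPy sketch below uses a hypothetical coalition matrix, not survshap internals, to show this, and uses `np.linalg.lstsq` as one possible fallback:

```python
import numpy as np

# Five sampled coalitions over p = 5 features (hypothetical); feature 4 is never switched on
X = np.array([
    [1, 0, 0, 0, 0],
    [0, 1, 0, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 1, 0],
    [1, 1, 0, 1, 0],
], dtype=float)
w = np.ones(len(X))  # kernel weights; any positive values give the same rank

A = (X.T * w) @ X    # equals X.T @ np.diag(w) @ X from the traceback
print(np.linalg.matrix_rank(A))  # 4, i.e. less than p = 5

try:
    np.linalg.inv(A)
except np.linalg.LinAlgError as err:
    print("inv failed:", err)  # inv failed: Singular matrix

# One possible fallback: least squares returns a minimum-norm solution instead of raising
y = np.arange(len(X), dtype=float)  # hypothetical stand-in for the centred predictions
beta, _, rank, _ = np.linalg.lstsq(A, (X.T * w) @ y, rcond=None)
print(beta.shape, rank)  # (5,) 4
```

Whether a fallback like this belongs in the library is a design decision; sampling enough coalitions that the coalition matrix has full column rank is likely the simpler remedy.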
@kaalen could you provide an update, please? What did you do?
I am also facing the same issue, even when using only 10% of my data (about 16000).
Did anybody figure this out?