
Peak memory usage goes up quadratically with d_t when X is not None

Open francis-doornaert-axa opened this issue 1 year ago • 3 comments

For our use case, we want to use DML with a continuous outcome variable, a discrete treatment (5000 levels), and 20-30 features. We want to use DML because it lets us specify a regularized model as model_final (we want to minimize the R-loss and don't care about unbiased estimates per se).

DML works as long as X is empty. But when X is not empty, peak memory usage seems to grow quadratically with the number of treatments.

The code below, with only 500 treatments, gives the following error on my laptop (which has only 16 GB of RAM): numpy.core._exceptions._ArrayMemoryError: Unable to allocate 27.8 GiB for an array with shape (2495000, 499, 3) and data type float64

According to the stack trace, the error originates in combine_, which calls cross_product.

from econml.dml import DML
from sklearn.dummy import DummyClassifier, DummyRegressor
import numpy as np

num_lines = 5000
np.random.seed(42)
X = np.random.normal(size=(num_lines, 2))
T = np.random.randint(500, size=num_lines)
y = 3*(T==1) -2*(T==2) + np.random.normal(size=(num_lines,))
est = DML(
    model_y=DummyRegressor(),
    model_t=DummyClassifier(),
    model_final=DummyRegressor(),
    discrete_treatment=True
)
est.fit(y, T, X=X)
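As a rough sanity check on the error message above, the sketch below recomputes the size of the reported array and projects how it would grow with the number of treatment levels. The decomposition of the shape as (rows × d_t, d_t, features) is an assumption read off the numbers in the error (2495000 = 5000 × 499), not something confirmed from the EconML source; `array_gib` and `projected_gib` are hypothetical helpers written for this illustration.

```python
import math

import numpy as np


def array_gib(shape, dtype=np.float64):
    """Size in GiB of a dense array with the given shape and dtype."""
    n_bytes = math.prod(shape) * np.dtype(dtype).itemsize
    return n_bytes / 2**30


# Shape taken verbatim from the error message.
reported = array_gib((2495000, 499, 3))
print(f"{reported:.1f} GiB")  # 27.8 GiB


def projected_gib(n_rows, d_t, n_feat):
    """ASSUMPTION: the array has shape (n_rows * d_t, d_t, n_feat)."""
    return array_gib((n_rows * d_t, d_t, n_feat))


# Under that assumption, doubling d_t quadruples the allocation.
for d_t in (499, 998, 4999):
    print(d_t, f"{projected_gib(5000, d_t, 3):.1f} GiB")
```

If the assumed shape holds, memory scales with d_t squared, which matches the quadratic growth described in this issue.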

francis-doornaert-axa · Feb 06 '25 14:02

I looked at this code because I thought it was related to a problem I'm having (which it isn't). For the record, though, if you change

est = DML(model_y=DummyRegressor(), model_t=DummyClassifier(), model_final=DummyRegressor(), discrete_treatment=True, cv=1)

to

est = DML(model_y=DummyRegressor(), model_t=DummyRegressor(), model_final=DummyRegressor(), discrete_treatment=False, cv=1)

the problem seems to disappear.

TimCosemans · Mar 12 '25 09:03

@TimCosemans Thanks, but we really need a discrete treatment in our use case. I admit the example is a bit weird.

In the meantime, we have a better grasp of the issue. In the final stage of the model, if X is not None, DML calls the broadcast_unit_treatments() function, which seems to create an array of size (#rows, (#treatments)^2); this is what blows up our memory usage.
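To illustrate the kind of blow-up described here, the following is a minimal NumPy sketch, not EconML's actual implementation. It assumes that every row is paired with every unit (one-hot) treatment vector and that the result is then crossed with the features plus an intercept column; all variable names are hypothetical. That assumption happens to reproduce the shape pattern in the reported error.

```python
import numpy as np

n, d_t, d_x = 4, 3, 2  # tiny sizes so the example runs instantly
rng = np.random.default_rng(0)
X = rng.normal(size=(n, d_x))

# Features with an intercept column (ASSUMPTION about the featurization).
feats = np.hstack([np.ones((n, 1)), X])        # shape (n, d_x + 1)

# Pair every row with every unit treatment vector.
unit_T = np.tile(np.eye(d_t), (n, 1))          # shape (n * d_t, d_t)
feats_rep = np.repeat(feats, d_t, axis=0)      # shape (n * d_t, d_x + 1)

# Row-wise outer product of treatments and features.
cross = np.einsum("ij,ik->ijk", unit_T, feats_rep)
print(cross.shape)   # (12, 3, 3), i.e. (n * d_t, d_t, d_x + 1)
print(cross.nbytes)  # 864 bytes = 12 * 3 * 3 * 8
```

Plugging n = 5000, d_t = 499, and d_x = 2 into the same shape formula gives (2495000, 499, 3), the shape from the error message. Most entries of such an array are zero, since each unit treatment vector has a single nonzero element.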

francis-doornaert-axa · Mar 12 '25 15:03

Updated issue description (better explanation of use case, simpler code example).

francis-doornaert-axa · Mar 13 '25 15:03