Peak memory usage goes up quadratically with d_t when X is not None
For our use case we want to use DML with a continuous outcome variable, a discrete treatment (5000 levels) and 20-30 features. We want to use DML because it allows us to specify a regularized model as the model_final (we want to minimize the R-loss and don't care about unbiased estimates per se).
DML works as long as X is None. But when X is not None, peak memory usage seems to grow quadratically with the number of treatments.
The code below, with only 500 treatment levels, gives the following error on my laptop (which has only 16 GB of RAM): numpy.core._exceptions._ArrayMemoryError: Unable to allocate 27.8 GiB for an array with shape (2495000, 499, 3) and data type float64
The stack trace shows that the error originates in combine_, which calls cross_product.
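The size in the error message can be checked by hand. The sketch below assumes (this is our reading of the shape, not something the traceback states) that the array has shape (n_rows * d_t, d_t, d_x + 1), where d_t = 499 is the number of treatment levels minus the baseline level and the extra column is a constant intercept term next to the 2 features. Under that assumption the element count contains d_t twice, which is where the quadratic growth comes from. The helper name is hypothetical.

```python
# Hypothetical helper: size in GiB of a dense float64 array with shape
# (n_rows * d_t, d_t, d_x + 1). The shape is an assumption inferred from
# the error message, not taken from the econml source.
def dense_cross_product_gib(n_rows: int, n_levels: int, d_x: int) -> float:
    d_t = n_levels - 1                      # one column per non-baseline level
    shape = (n_rows * d_t, d_t, d_x + 1)    # +1 assumed to be an intercept column
    n_elements = shape[0] * shape[1] * shape[2]
    return n_elements * 8 / 2**30           # float64 uses 8 bytes per element

# Reported case: 5000 rows, 500 treatment levels, 2 features
print(round(dense_cross_product_gib(5000, 500, 2), 1))  # 27.8

# Doubling d_t (500 -> 1000) roughly quadruples the allocation
print(dense_cross_product_gib(5000, 1001, 2) / dense_cross_product_gib(5000, 501, 2))
```

With the real use case (5000 levels) the same formula gives well over 2 TiB, so the problem is not specific to a 16 GB laptop.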
from econml.dml import DML
from sklearn.dummy import DummyClassifier, DummyRegressor
import numpy as np
num_lines = 5000
np.random.seed(42)
X = np.random.normal(size=(num_lines, 2))
T = np.random.randint(500, size=num_lines)
y = 3*(T==1) -2*(T==2) + np.random.normal(size=(num_lines,))
est = DML(
    model_y=DummyRegressor(),
    model_t=DummyClassifier(),
    model_final=DummyRegressor(),
    discrete_treatment=True
)
est.fit(y, T, X=X)
I looked at this code because I thought it was related to a problem I'm having (it isn't). But, just for the record, if you change
est = DML(model_y=DummyRegressor(), model_t=DummyClassifier(), model_final=DummyRegressor(), discrete_treatment=True, cv=1)
to
est = DML(model_y=DummyRegressor(), model_t=DummyRegressor(), model_final=DummyRegressor(), discrete_treatment=False, cv=1)
the problem seems to disappear.
@TimCosemans Thanks, but we really need a discrete treatment in our use case. I admit the example is a bit weird.
In the meantime, we have a better grasp of the issue. In the final stage of the model, if X is not None, DML calls the broadcast_unit_treatments() function, which seems to create an array of size (#rows, (#treatments)^2). This is what blows up our memory usage.
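To illustrate why pairing every row with every unit treatment is quadratic, here is a small NumPy sketch. It is a hypothetical re-creation, not econml's code: broadcast_unit_treatments_sketch and dense_cross_product are made-up names, and the intercept column is an assumption chosen so that the result has the same (n_rows * d_t, d_t, d_x + 1) layout as the shape in the error message.

```python
import numpy as np

def broadcast_unit_treatments_sketch(X, d_t):
    """Hypothetical: pair every row of X with each of the d_t one-hot treatments."""
    n = X.shape[0]
    X_rep = np.repeat(X, d_t, axis=0)       # (n * d_t, d_x): each row repeated d_t times
    T_unit = np.tile(np.eye(d_t), (n, 1))   # (n * d_t, d_t): identity matrix stacked n times
    return X_rep, T_unit

def dense_cross_product(X, T):
    """Hypothetical: per-row outer product of T with [1, X] -> (rows, d_t, d_x + 1)."""
    X1 = np.hstack([np.ones((X.shape[0], 1)), X])  # assumed intercept column
    return np.einsum("ij,ik->ijk", T, X1)

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 2))                 # tiny stand-in for the 5000 x 2 features
X_rep, T_unit = broadcast_unit_treatments_sketch(X, d_t=4)
features = dense_cross_product(X_rep, T_unit)
print(features.shape)                       # (20, 4, 3)
print((features != 0).mean())               # 0.25: only 1 of the 4 treatment slices is nonzero
```

Each row of the result has only d_x + 1 nonzero entries out of d_t * (d_x + 1), so storing it densely wastes a factor of d_t on top of the n * d_t rows created by the broadcast.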
Updated issue description (better explanation of use case, simpler code example).