
predict() function definition lacks the optional argument "segment"

Open · teancake opened this issue 1 year ago · 0 comments

The predict() method in the files

qlib/contrib/model/pytorch_lstm_ts.py
qlib/contrib/model/pytorch_tcn_ts.py

is defined as

```python
def predict(self, dataset):
```

which lacks the optional argument "segment" declared in the base class:

```python
class Model(BaseModel):

    def predict(self, dataset: Dataset, segment: Union[Text, slice] = "test") -> object:
```

As a result, predict() in the LSTM and TCN models fails whenever the prediction segment is passed explicitly, so these models are not interchangeable with others such as LGBModel. For example, the following call raises a TypeError:

```python
model.predict(dataset=dataset, segment="pred")
```
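The failure can be reproduced without qlib at all. The sketch below is a minimal illustration under stated assumptions: `BaseModel` and `TSModel` are hypothetical stand-ins (not qlib classes) for the base `Model` and for the LSTM/TCN overrides, and `try_predict` is a hypothetical helper playing the role of generic workflow code that passes `segment` by keyword.

```python
from typing import Text, Union


class BaseModel:
    """Hypothetical stand-in for qlib's Model base class (simplified)."""

    def predict(self, dataset, segment: Union[Text, slice] = "test"):
        return f"predicting on {segment!r}"


class TSModel(BaseModel):
    """Hypothetical stand-in for the LSTM/TCN models: the override drops `segment`."""

    def predict(self, dataset):
        return "predicting on 'test'"


def try_predict(model, segment):
    # Call predict() the way generic code would, passing `segment` by keyword
    try:
        return model.predict(dataset=None, segment=segment)
    except TypeError as err:
        return f"TypeError: {err}"


print(try_predict(BaseModel(), "pred"))
print(try_predict(TSModel(), "pred"))
```

The first call succeeds; the second reports an unexpected keyword argument 'segment', because the override narrowed the base-class signature.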

Possible revision:

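```python
    # Signature now matches Model.predict in the base class
    def predict(self, dataset: Dataset, segment: Union[Text, slice] = "test") -> object:
        if not self.fitted:
            raise ValueError("model is not fitted yet!")
        # Prepare the requested segment instead of the hard-coded "test"
        dl_test = dataset.prepare(segment, col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
        # ... remainder of the method unchanged
```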
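To check that the revised signature forwards the caller's segment to `dataset.prepare()`, here is a self-contained sketch. It assumes hypothetical stand-ins only: `FakeDataset` records which segment was requested, `FixedTSModel` mirrors the revised `predict()` but returns the prepared data directly instead of running a network, and `DK_I` is a placeholder for `DataHandlerLP.DK_I`. None of these are qlib APIs.

```python
from typing import Text, Union


class FakeDataset:
    """Hypothetical stand-in for a qlib dataset: records requested segments."""

    def __init__(self):
        self.requested = []

    def prepare(self, segments, col_set=None, data_key=None):
        self.requested.append(segments)
        return [0.1, 0.2, 0.3]  # dummy data in place of real features


class FixedTSModel:
    """Sketch of the revised predict(); model inference is omitted."""

    DK_I = "infer"  # placeholder for DataHandlerLP.DK_I

    def __init__(self, fitted=True):
        self.fitted = fitted

    def predict(self, dataset, segment: Union[Text, slice] = "test"):
        if not self.fitted:
            raise ValueError("model is not fitted yet!")
        # Forward the caller's segment instead of hard-coding "test"
        return dataset.prepare(segment, col_set=["feature", "label"], data_key=self.DK_I)


ds = FakeDataset()
model = FixedTSModel()
model.predict(ds)                                   # default: "test"
model.predict(dataset=ds, segment="pred")           # keyword call from the issue
model.predict(ds, segment=slice("2020-01-01", "2020-12-31"))
print(ds.requested)
```

Both the default call and the explicit `segment="pred"` call now work, and a date slice is accepted as the `Union[Text, slice]` annotation allows.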

teancake · Jul 03 '24 14:07