qlib
predict() function definition lacks an optional argument "segment"
The definition of `predict()` in the files

- qlib/contrib/model/pytorch_lstm_ts.py
- qlib/contrib/model/pytorch_tcn_ts.py

```python
def predict(self, dataset):
```

lacks the optional argument `segment` that the base class defines:

```python
class Model(BaseModel):
    def predict(self, dataset: Dataset, segment: Union[Text, slice] = "test") -> object:
        ...  # docstring and body omitted
```
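As a result, `predict()` in the LSTM and TCN models fails with a `TypeError` whenever the prediction segment is passed explicitly. This also makes them inconsistent with other models such as `LGBModel`, which accept the argument. For example, this call fails:

```python
model.predict(dataset=dataset, segment="pred")
```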
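The failure can be reproduced without qlib. The sketch below is not qlib code: `FakeDataset`, `BaseModelSketch`, `TSModelSketch` and `try_predict` are hypothetical stand-ins that only mirror the two signatures quoted above, so that the effect of dropping `segment` in an override is visible.

```python
from typing import Text, Union


class FakeDataset:
    """Hypothetical stand-in for qlib's Dataset; echoes the requested segment."""

    def prepare(self, segment: Union[Text, slice]) -> str:
        return f"data for {segment}"


class BaseModelSketch:
    """Mirrors the base-class signature: segment is optional, default "test"."""

    def predict(self, dataset: FakeDataset, segment: Union[Text, slice] = "test") -> str:
        return dataset.prepare(segment)


class TSModelSketch(BaseModelSketch):
    """Mirrors the LSTM/TCN override, which drops the segment parameter."""

    def predict(self, dataset: FakeDataset) -> str:  # type: ignore[override]
        return dataset.prepare("test")


def try_predict(model: BaseModelSketch, segment: str) -> str:
    """Call predict() with an explicit segment, as one would with LGBModel."""
    try:
        return model.predict(dataset=FakeDataset(), segment=segment)
    except TypeError as exc:
        # The override has no 'segment' parameter, so Python rejects the keyword.
        return f"TypeError: {exc}"


if __name__ == "__main__":
    print(try_predict(BaseModelSketch(), "pred"))  # → data for pred
    print(try_predict(TSModelSketch(), "pred"))
```

The second call reports an unexpected keyword argument `'segment'`; the exact wording of the message varies between Python versions.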
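Possible revision (accept `segment` and pass it to `dataset.prepare()`):

```python
from typing import Text, Union
from qlib.data.dataset import Dataset
from qlib.data.dataset.handler import DataHandlerLP

def predict(self, dataset: Dataset, segment: Union[Text, slice] = "test") -> object:
    if not self.fitted:
        raise ValueError("model is not fitted yet!")
    dl_test = dataset.prepare(segment, col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
    # ... remainder of the method unchanged
```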
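A minimal, self-contained sketch of the revised behaviour follows, again using hypothetical stand-ins rather than qlib classes (`FakeDataset`, `BaseModelSketch`, `FixedTSModelSketch`). The helper `override_keeps_params` is also hypothetical. It uses `inspect.signature` to check that an override still accepts every parameter name declared on the base method, which is one way a test could catch this kind of signature drift in other models.

```python
import inspect
from typing import Text, Union


class FakeDataset:
    """Hypothetical stand-in for qlib's Dataset; echoes the requested segment."""

    def prepare(self, segment: Union[Text, slice]) -> str:
        return f"data for {segment}"


class BaseModelSketch:
    """Mirrors the base-class signature quoted in the issue."""

    def predict(self, dataset: FakeDataset, segment: Union[Text, slice] = "test") -> str:
        raise NotImplementedError


class FixedTSModelSketch(BaseModelSketch):
    """Sketch of the revised override: segment is accepted and forwarded."""

    def __init__(self, fitted: bool = True) -> None:
        self.fitted = fitted

    def predict(self, dataset: FakeDataset, segment: Union[Text, slice] = "test") -> str:
        if not self.fitted:
            raise ValueError("model is not fitted yet!")
        # Forward the caller's segment instead of a fixed one.
        return dataset.prepare(segment)


def override_keeps_params(cls: type, base: type, method: str = "predict") -> bool:
    """Return True if cls.<method> accepts every parameter name of base.<method>."""
    base_params = inspect.signature(getattr(base, method)).parameters
    sub_params = inspect.signature(getattr(cls, method)).parameters
    return all(name in sub_params for name in base_params)


if __name__ == "__main__":
    model = FixedTSModelSketch()
    print(model.predict(FakeDataset()))                  # → data for test
    print(model.predict(FakeDataset(), segment="pred"))  # → data for pred
    print(override_keeps_params(FixedTSModelSketch, BaseModelSketch))  # → True
```

Because the default stays `"test"`, existing calls such as `model.predict(dataset)` keep their behaviour. Only callers that pass `segment` explicitly see a difference.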