skops can't load TargetEncoder object
skops is unable to load an sklearn pipeline containing a TargetEncoder object. The code below reproduces the issue.
Simple code for a classifier using TargetEncoder:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from category_encoders import TargetEncoder
from sklearn.metrics import accuracy_score
# Sample data
data = {
'category': ['A', 'B', 'A', 'C', 'B', 'A', 'C', 'C', 'B', 'A'],
'feature1': [10, 20, 10, 30, 20, 10, 30, 30, 20, 10],
'feature2': [1, 2, 1, 3, 2, 1, 3, 3, 2, 1],
'target': [0, 1, 0, 1, 1, 0, 1, 1, 1, 0]
}
# Create DataFrame
df = pd.DataFrame(data)
# Separate features and target
X = df.drop('target', axis=1)
y = df['target']
# Encode the category column using Target Encoder
encoder = TargetEncoder()
X['category'] = encoder.fit_transform(X['category'], y)
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Initialize and train the classifier
clf = RandomForestClassifier()
clf.fit(X_train, y_train)
# Make predictions
y_pred = clf.predict(X_test)
# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy}")
Using skops to dump the encoder to a file and get the unknown types:
from skops import io as sio
sio.dump(encoder, "test.sio")
unknown_type = sio.get_untrusted_types(file="test.sio")
print(unknown_type)
Output of unknown_type:
['builtins.object',
'category_encoders.ordinal.OrdinalEncoder',
'category_encoders.target_encoder.TargetEncoder',
'numpy.dtype',
'pandas._libs.index.Int64Engine',
'pandas._libs.index.ObjectEngine',
'pandas._libs.internals.BlockValuesRefs',
'pandas.core.indexes.base.Index',
'pandas.core.internals.managers.SingleBlockManager',
'pandas.core.series.Series']
Using skops to load the file:
sio.load("test.sio", trusted=unknown_type)
This raises the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[14], line 1
----> 1 sio.load("test.sio",trusted=unknown_type)
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_persist.py:152, in load(file, trusted)
150 tree = get_tree(schema, load_context, trusted=trusted)
151 audit_tree(tree)
--> 152 instance = tree.construct()
154 return instance
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:420, in ObjectNode._construct(self)
416 if not self.children["attrs"]:
417 # nothing more to do
418 return instance
--> 420 attrs = self.children["attrs"].construct()
421 if attrs is not None:
422 if hasattr(instance, "__setstate__"):
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:79, in DictNode._construct(self)
77 key_types = self.children["key_types"].construct()
78 for k_type, (key, val) in zip(key_types, self.children["content"].items()):
---> 79 content[k_type(key)] = val.construct()
80 return content
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:79, in DictNode._construct(self)
77 key_types = self.children["key_types"].construct()
78 for k_type, (key, val) in zip(key_types, self.children["content"].items()):
---> 79 content[k_type(key)] = val.construct()
80 return content
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:420, in ObjectNode._construct(self)
416 if not self.children["attrs"]:
417 # nothing more to do
418 return instance
--> 420 attrs = self.children["attrs"].construct()
421 if attrs is not None:
422 if hasattr(instance, "__setstate__"):
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:79, in DictNode._construct(self)
77 key_types = self.children["key_types"].construct()
78 for k_type, (key, val) in zip(key_types, self.children["content"].items()):
---> 79 content[k_type(key)] = val.construct()
80 return content
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:420, in ObjectNode._construct(self)
416 if not self.children["attrs"]:
417 # nothing more to do
418 return instance
--> 420 attrs = self.children["attrs"].construct()
421 if attrs is not None:
422 if hasattr(instance, "__setstate__"):
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:179, in TupleNode._construct(self)
175 def _construct(self):
176 # Returns a tuple or a namedtuple instance.
178 cls = gettype(self.module_name, self.class_name)
--> 179 content = tuple(value.construct() for value in self.children["content"])
181 if self.isnamedtuple(cls):
182 return cls(*content)
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:179, in <genexpr>(.0)
175 def _construct(self):
176 # Returns a tuple or a namedtuple instance.
178 cls = gettype(self.module_name, self.class_name)
--> 179 content = tuple(value.construct() for value in self.children["content"])
181 if self.isnamedtuple(cls):
182 return cls(*content)
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:113, in ListNode._construct(self)
111 def _construct(self):
112 content_type = gettype(self.module_name, self.class_name)
--> 113 return content_type([item.construct() for item in self.children["content"]])
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:113, in <listcomp>(.0)
111 def _construct(self):
112 content_type = gettype(self.module_name, self.class_name)
--> 113 return content_type([item.construct() for item in self.children["content"]])
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_audit.py:165, in Node.construct(self)
163 if self._constructed is not UNINITIALIZED:
164 return self._constructed
--> 165 self._constructed = self._construct()
166 return self._constructed
File ~/miniconda3/envs/project/lib/python3.10/site-packages/skops/io/_general.py:414, in ObjectNode._construct(self)
408 cls = gettype(self.module_name, self.class_name)
410 # Instead of simply constructing the instance, we use __new__, which
411 # bypasses the __init__, and then we set the attributes. This solves the
412 # issue of required init arguments. Note that the instance created here
413 # might not be valid until all its attributes have been set below.
--> 414 instance = cls.__new__(cls) # type: ignore
416 if not self.children["attrs"]:
417 # nothing more to do
418 return instance
File ~/miniconda3/envs/project/lib/python3.10/site-packages/pandas/core/indexes/base.py:526, in Index.__new__(cls, data, dtype, copy, name, tupleize_cols)
523 data = com.asarray_tuplesafe(data, dtype=_dtype_obj)
525 elif is_scalar(data):
--> 526 raise cls._raise_scalar_data_error(data)
527 elif hasattr(data, "__array__"):
528 return cls(np.asarray(data), dtype=dtype, copy=copy, name=name)
File ~/miniconda3/envs/project/lib/python3.10/site-packages/pandas/core/indexes/base.py:5289, in Index._raise_scalar_data_error(cls, data)
5284 @final
5285 @classmethod
5286 def _raise_scalar_data_error(cls, data):
5287 # We return the TypeError so that we can raise it from the constructor
5288 # in order to keep mypy happy
-> 5289 raise TypeError(
5290 f"{cls.__name__}(...) must be called with a collection of some "
5291 f"kind, {repr(data) if not isinstance(data, np.generic) else str(data)} "
5292 "was passed"
5293 )
TypeError: Index(...) must be called with a collection of some kind, None was passed
Environment: Python 3.10.14
dill==0.3.8
docker==7.1.0
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1720869315914/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work
filelock==3.16.1
flatbuffers==24.3.25
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2024.6.1
gevent==24.2.1
geventhttpclient==2.0.2
google-pasta==0.2.0
greenlet==3.0.3
grpcio==1.64.1
huggingface-hub==0.26.2
humanfriendly==10.0
idna==3.7
imbalanced-learn==0.12.0
importlib-metadata==6.11.0
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1719582526268/work
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work
jmespath==1.0.1
joblib==1.4.2
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1716472197302/work
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257277185/work
kiwisolver==1.4.5
lightgbm==4.5.0
matplotlib==3.9.1
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1713250518406/work
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work
numpy==1.26.4
oc==0.2.1
onnx==1.17.0
onnxconverter-common==1.14.0
onnxmltools==1.12.0
onnxruntime==1.16.3
oracledb==2.0.1
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1718189413536/work
pandas==2.2.1
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1712320355065/work
pathos==0.3.2
patsy==0.5.6
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
pillow==10.4.0
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1715777629804/work
pox==0.3.4
ppft==1.7.6.8
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1718047967974/work
protobuf==3.20.2
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1719274566094/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
pyathena==2.3.2
pycparser==2.22
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1714846767233/work
pyparsing==3.1.2
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1709299778482/work
python-rapidjson==1.17
pytz==2024.1
PyYAML==6.0.1
pyzmq @ file:///croot/pyzmq_1705605076900/work
referencing==0.35.1
requests==2.32.3
rpds-py==0.19.0
s3transfer==0.10.1
sagemaker==2.226.0
schema==0.7.7
scikit-learn==1.4.0
scipy==1.13.1
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
skl2onnx==1.17.0
skops==0.10.0
smdebug-rulesconfig==1.0.1
stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
statsmodels==0.14.2
sympy==1.13.3
tabulate==0.9.0
tblib==3.0.0
tenacity==8.4.1
threadpoolctl==3.5.0
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1717722796999/work
tqdm==4.66.4
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1713535121073/work
tritonclient==2.46.0
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1717802530399/work
tzdata==2024.1
urllib3==2.2.2
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work
xgboost==2.1.2
yarl==1.9.4
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1718013267051/work
zope.event==5.0
zope.interface==6.4.post2
Seems the core of the issue is that we don't support pandas yet, and adding that has been on my mind. I'll draw up a PR to fix this. Thanks for the nicely written report, BTW.
The issue with this is that there doesn't seem to be any way to save/load a pandas object while preserving all the data, without an external dependency. Need to figure out what to do here.
Any thoughts, @jorisvandenbossche?
Can you explain what you mean by "without an external dependency"? What's the exact issue you are running into?
I am not familiar with the approach in skops.io, but pandas has custom support for pickling, so I assume a similar approach should be possible for skops?
Here we avoid using __reduce__ unless we have to, and that's for a very few cases.
Looking at pandas' code, I see that in a lot of cases __reduce__ returns the object's own type as a constructor, which is fine, but there are also cases where a separate constructor function is called.
Generally speaking, we're trying to avoid relying on __reduce__ unless necessary. However, relying on __reduce__ when the constructor is type(self) is okay.
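To illustrate the distinction, a minimal sketch with hypothetical classes (not pandas or skops code):

class EasyToAudit:
    def __init__(self, x):
        self.x = x

    def __reduce__(self):
        # The constructor is type(self): reconstruction only ever calls
        # the class itself, which is straightforward to audit.
        return (type(self), (self.x,))

def _reconstruct(x):
    # A free-standing constructor function: reconstruction runs whatever
    # code lives here, which is what skops tries not to rely on.
    return HardToAudit(x)

class HardToAudit:
    def __init__(self, x):
        self.x = x

    def __reduce__(self):
        return (_reconstruct, (self.x,))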
I was hoping for something along the lines of np.save(..., allow_pickle=False). But going through https://pandas.pydata.org/docs/user_guide/io.html, it seems the best option would be parquet, which would require at least fastparquet.
I'm contemplating adding that as a dependency, since it's small.
Also, WDYT @BenjaminBossan
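For reference, the kind of pickle-free round-trip parquet would enable might look like this (a sketch, not skops API; it assumes a parquet engine such as pyarrow or fastparquet is installed):

import io
import pandas as pd

df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

# Serialize to an in-memory parquet blob; no pickle involved.
buf = io.BytesIO()
df.to_parquet(buf)

# Read it back, dtypes and all.
buf.seek(0)
restored = pd.read_parquet(buf)
assert restored.equals(df)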
I could imagine that fully supporting pandas would open up a big can of worms. If it's only about TargetEncoder, I wonder if a special-case solution could be implemented that takes whatever attribute is responsible and converts it to, say, a Python list (and back to the pandas type for deserialization).
However, when it comes to the provided example, I saw that it uses from category_encoders import TargetEncoder. This is not the sklearn TargetEncoder, so it's not really in scope for skops to support it, right? I tried replacing it with sklearn's TargetEncoder, but that runs into several issues with the given code, which I didn't fully resolve. Maybe let's first try to determine whether there are any issues with sklearn objects and pandas attributes and fix those first?
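To make the special-case idea concrete, such a conversion could look roughly like the sketch below (the helper names are hypothetical, not skops internals):

import pandas as pd

# Hypothetical helpers that break a pandas object into plain-Python
# pieces skops already knows how to persist, and rebuild it on load.

def series_to_plain(s: pd.Series) -> dict:
    return {
        "values": s.tolist(),
        "index": s.index.tolist(),
        "name": s.name,
        "dtype": str(s.dtype),
    }

def series_from_plain(d: dict) -> pd.Series:
    return pd.Series(d["values"], index=d["index"], name=d["name"], dtype=d["dtype"])

s = pd.Series([0.1, 0.7], index=["A", "B"], name="target_mean")
assert series_from_plain(series_to_plain(s)).equals(s)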
If it's only about TargetEncoder
The rest of Benjamin's comment might make it no longer relevant, but indeed, for supporting this use case, I would first check what TargetEncoder exactly needs in terms of pandas support. For example, if it just stores the unique values as a pd.Index object (the error traceback above was about the index), it might be sufficient to support that through the existing support for numpy arrays (if you get the numpy array of an Index and serialize that / reconstruct it from that, I think that should already cover quite a few of the use cases).
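Concretely, that Index round-trip might look like this (a sketch, not actual skops code):

import numpy as np
import pandas as pd

idx = pd.Index(["A", "B", "C"], name="category")

# Persist only things skops already supports: a numpy array plus a
# plain-Python name.
arr = np.asarray(idx)
name = idx.name

# On load, rebuild the Index from those pieces.
restored = pd.Index(arr, name=name)
assert restored.equals(idx)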
Looking at pandas' code, I see that in a lot of cases __reduce__ returns the object's own type as a constructor, which is fine, but there are also cases where a separate constructor function is called. Generally speaking, we're trying to avoid relying on __reduce__ unless necessary. However, relying on __reduce__ when the constructor is type(self) is okay.
I think the main reason for pandas doing that is because the actual class constructor is often doing too much (like type inference, which we don't want on unpickle)
@BenjaminBossan
This is not the sklearn TargetEncoder, so it's not really in scope for skops to support it, right?
I'd argue that it is. We support xgboost, catboost, quantile_forest, etc. Basically, the sklearn ecosystem, and sklearn compatible estimators are certainly in that ecosystem.
@jorisvandenbossche
For example, if it just stores the unique values as an pd.Index object (the error traceback above was about the index), it might be sufficient to support that through the existing support of numpy arrays (if you get the numpy array of an Index and serialize that / reconstruct it from that, I think that should already cover quite some of the use cases)
That's a very good point. We'd need to check what exactly needs to be persisted. (Maybe @lazarust would be happy to check?)
I think the main reason for pandas doing that is because the actual class constructor is often doing too much (like type inference, which we don't want on unpickle)
One can use this pattern, for instance, to avoid calling the constructor:
obj = Obj.__new__(Obj)                 # allocate the instance, bypassing __init__
obj.__setstate__(state_from_getstate)  # restore the state returned by __getstate__
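For example, with a toy class (purely illustrative):

class Obj:
    def __init__(self, path):
        # Imagine an expensive or side-effectful __init__ we want to skip.
        self.data = open(path).read()

    def __getstate__(self):
        return {"data": self.data}

    def __setstate__(self, state):
        self.__dict__.update(state)

# Rebuild an instance from saved state without ever running __init__:
state = {"data": "hello"}
obj = Obj.__new__(Obj)
obj.__setstate__(state)
assert obj.data == "hello"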
I'd argue that it is. We support xgboost, catboost, quantile_forest, etc. Basically, the sklearn ecosystem, and sklearn compatible estimators are certainly in that ecosystem.
Yes, the idea is to also support the sklearn ecosystem. But I think that when it comes to this, it's a matter of trade-offs, i.e. how popular the package is vs. how difficult it is to support. In the end, it's up to you to make the call; however, I would certainly not want to support all of pandas just for category_encoders.TargetEncoder.
I think sooner or later we need to support pandas since it's used in the ecosystem quite a lot, and this TargetEncoder is probably not the only one storing pandas attributes.
However, I'm tempted to use fastparquet as a soft dependency for that, to lower maintenance burden here. Also, I checked polars, and there's a native (to polars) parquet writer/reader which we can use for polars dataframes. That makes our dataframe persistence somewhat consistent.
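For polars, that native round-trip might look roughly like this (a sketch, assuming a recent polars version):

import io
import polars as pl

df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

# polars ships its own parquet writer/reader; no extra engine needed.
buf = io.BytesIO()
df.write_parquet(buf)

buf.seek(0)
restored = pl.read_parquet(buf)
assert restored.equals(df)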
I'd be happy to look into what needs to be persisted (it may take me a week or so due to the holidays)!
@adrinjalali
However, I'm tempted to use fastparquet as a soft dependency for that, to lower maintenance burden here.
I think that if we want to support Pandas more broadly, it may make more sense to use parquet rather than fiddling around making sure the right things are getting persisted for each class using pandas data types as attributes.
@adrinjalali I've dug through this a little bit, and it seems that category_encoders uses pandas Index objects to store unique values.
I agree with @jorisvandenbossche that storing it as a numpy array and reconstructing from that would make sense.
What are y'all's thoughts? I do think adding fastparquet would be a more "full-featured" solution that would probably require less edge-case fixing if other libraries need more than just the Index.