SDV
SDV copied to clipboard
Validate that the columns passed to a constraint exist in the data
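We don't validate that a constraint is given only columns that actually exist in the data. Passing a nonexistent column can therefore crash later with a confusing error message that doesn't point to the misconfigured constraint. Below is such an example: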
import sdv  # not used below; kept so the snippet matches the cell that produced the traceback
from sdv.constraints import Unique, OneHotEncoding  # Unique is not used below
import pandas as pd
import numpy as np  # not used below
from sdv.tabular import GaussianCopula
data = pd.DataFrame({  # the data has only the columns 'a' and 'b'
'a': [1, 0],
'b': [0, 1],
})
cnt = OneHotEncoding(['a','b','c'])  # 'c' is not a column of `data`, yet constructing the constraint succeeds
model = GaussianCopula(constraints=[cnt])  # also succeeds; nothing checks the constraint's columns here
model.fit(data)  # raises pandas KeyError "['c'] not in index" from inside Constraint.fit, not a clear validation error
model.sample(10)  # never reached, because fit() above raises
This produces the following traceback:
Input In [10], in <cell line: 12>()
10 cnt = OneHotEncoding(['a','b','c'])
11 model = GaussianCopula(constraints=[cnt])
---> 12 model.fit(data)
13 model.sample(10)
File ~/Desktop/SDV/sdv/tabular/base.py:147, in BaseTabularModel.fit(self, data)
144 LOGGER.debug('Fitting %s to table %s; shape: %s', self.__class__.__name__,
145 self._metadata.name, data.shape)
146 if not self._metadata_fitted:
--> 147 self._metadata.fit(data)
149 self._num_rows = len(data)
151 LOGGER.debug('Transforming table %s; shape: %s', self._metadata.name, data.shape)
File ~/Desktop/SDV/sdv/metadata/table.py:588, in Table.fit(self, data)
585 data = self._anonymize(data)
587 LOGGER.info('Fitting constraints for table %s', self.name)
--> 588 constrained = self._fit_transform_constraints(data)
589 extra_columns = set(constrained.columns) - set(data.columns)
591 LOGGER.info('Fitting HyperTransformer for table %s', self.name)
File ~/Desktop/SDV/sdv/metadata/table.py:447, in Table._fit_transform_constraints(self, data)
445 def _fit_transform_constraints(self, data):
446 for constraint in self._constraints:
--> 447 data = constraint.fit_transform(data)
449 return data
File ~/Desktop/SDV/sdv/constraints/base.py:294, in Constraint.fit_transform(self, table_data)
283 def fit_transform(self, table_data):
284 """Fit this Constraint to the data and then transform it.
285
286 Args:
(...)
292 Transformed data.
293 """
--> 294 self.fit(table_data)
295 return self.transform(table_data)
File ~/Desktop/SDV/sdv/constraints/base.py:152, in Constraint.fit(self, table_data)
149 self._fit(table_data)
151 if self.fit_columns_model and len(self.constraint_columns) > 1:
--> 152 data_to_model = table_data[list(self.constraint_columns)]
153 self._hyper_transformer = HyperTransformer(default_data_type_transformers={
154 'categorical': 'OneHotEncodingTransformer',
155 })
156 transformed_data = self._hyper_transformer.fit_transform(data_to_model)
File ~/opt/anaconda3/envs/sdv/lib/python3.9/site-packages/pandas/core/frame.py:3511, in DataFrame.__getitem__(self, key)
3509 if is_iterator(key):
3510 key = list(key)
-> 3511 indexer = self.columns._get_indexer_strict(key, "columns")[1]
3513 # take() does not accept boolean indexers
3514 if getattr(indexer, "dtype", None) == bool:
File ~/opt/anaconda3/envs/sdv/lib/python3.9/site-packages/pandas/core/indexes/base.py:5782, in Index._get_indexer_strict(self, key, axis_name)
5779 else:
5780 keyarr, indexer, new_indexer = self._reindex_non_unique(keyarr)
-> 5782 self._raise_if_missing(keyarr, indexer, axis_name)
5784 keyarr = self.take(indexer)
5785 if isinstance(key, Index):
5786 # GH 42790 - Preserve name from an Index
File ~/opt/anaconda3/envs/sdv/lib/python3.9/site-packages/pandas/core/indexes/base.py:5845, in Index._raise_if_missing(self, key, indexer, axis_name)
5842 raise KeyError(f"None of [{key}] are in the [{axis_name}]")
5844 not_found = list(ensure_index(key)[missing_mask.nonzero()[0]].unique())
-> 5845 raise KeyError(f"{not_found} not in index")
KeyError: "['c'] not in index"