SDV
PAR: DiagnosticReport score is not 1.0 with float categorical columns
Environment Details
Please indicate the following details about the environment in which you found the bug:
- SDV version:
- Python version:
- Operating System:
Error Description
When PAR is fit on categorical columns whose values are floats, it does not adhere to the original categories when sampling: the synthetic data contains values that never appear among the real categories. This produces a very low 'Data Validity' diagnostic score, because the CategoryAdherence metric fails for those columns.
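To make the failure concrete: CategoryAdherence roughly measures the share of synthetic values that occur among the real categories. The sketch below is a simplified, pure-Python approximation of that idea for illustration only. It is not the SDMetrics implementation, it ignores missing-value handling, and the function name `category_adherence` is hypothetical.

```python
def category_adherence(real_values, synthetic_values):
    """Approximate CategoryAdherence: fraction of synthetic values found in the real column.

    Simplified illustration (hypothetical helper), not the SDMetrics code.
    """
    real_categories = set(real_values)
    synthetic_values = list(synthetic_values)
    if not synthetic_values:
        raise ValueError('synthetic_values must not be empty')
    matches = sum(1 for value in synthetic_values if value in real_categories)
    return matches / len(synthetic_values)


real = [100.0, 50.0, 100.0, 50.0]
# Float "categories" that drift during sampling no longer match the originals.
synthetic = [100.0, 73.4, 50.0, 61.9]
print(category_adherence(real, synthetic))  # → 0.5
```

Any sampled float that is not exactly one of the original category values counts as a miss, which is why the score drops well below 1.0.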
Steps to reproduce
from sdv.datasets.demo import download_demo
from sdv.sequential import PARSynthesizer
from sdv.evaluation.single_table import run_diagnostic
data, metadata = download_demo('sequential', 'nasdaq100_2019')
data['category'] = [100.0 if i % 2 == 0 else 50.0 for i in data.index]
metadata.add_column('category', sdtype='categorical')
synth = PARSynthesizer(metadata)
synth.fit(data)
sampled = synth.sample(2)
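report = run_diagnostic(data, sampled, metadata)
# inspect the per-column results behind the 'Data Validity' score
report.get_details(property_name='Data Validity')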
Workaround
If anyone is running into this, here is a suggested workaround:
- Identify any categorical columns (in the metadata) that are actually represented as numbers in your data (ints, floats, etc.).
- Cast these columns as objects before passing the data to the PARSynthesizer.
- After sampling, cast these columns in the synthetic data back to their original types (ints, floats, etc.).
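The first step can be automated. The sketch below assumes the column metadata is available as a plain dict shaped like `{'col': {'sdtype': 'categorical'}, ...}`; in SDV 1.x, `metadata.to_dict()['columns']` should have roughly this shape, but check it against your version. The helper name `find_numeric_categorical_columns` is hypothetical.

```python
import pandas as pd
from pandas.api.types import is_numeric_dtype


def find_numeric_categorical_columns(data, columns_metadata):
    """Return columns marked 'categorical' in the metadata whose data dtype is numeric.

    columns_metadata is assumed to map column name -> {'sdtype': ..., ...}.
    """
    return [
        name
        for name, spec in columns_metadata.items()
        if spec.get('sdtype') == 'categorical'
        and name in data.columns
        and is_numeric_dtype(data[name])
    ]


data = pd.DataFrame({
    'category': [100.0, 50.0],     # categorical, but stored as floats
    'Symbol': ['AAPL', 'MSFT'],     # string values, so not affected
    'Volume': [1.0, 2.0],           # numerical sdtype, so not affected
})
columns = {
    'category': {'sdtype': 'categorical'},
    'Symbol': {'sdtype': 'id'},
    'Volume': {'sdtype': 'numerical'},
}
print(find_numeric_categorical_columns(data, columns))  # → ['category']
```

The resulting list can be used directly as `CAT_COLUMN_NAMES`.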
Here is a code snippet that accomplishes the above. Replace the list CAT_COLUMN_NAMES with the names of your columns.
from sdv.sequential import PARSynthesizer
CAT_COLUMN_NAMES = ['ColA', 'ColB']  # replace with all of your numeric categorical columns
data = ...  # <your pandas DataFrame>
metadata = ...  # <your SingleTableMetadata object>
# cast the categorical columns to the object dtype
for col_name in CAT_COLUMN_NAMES:
    data[col_name] = data[col_name].astype('object')
# now proceed with modeling and sampling as usual
synthesizer = PARSynthesizer(metadata)
synthesizer.fit(data)
synthetic_data = synthesizer.sample(num_sequences=10)
# (optional) cast the categorical columns back to floats
for col_name in CAT_COLUMN_NAMES:
    try:
        synthetic_data[col_name] = synthetic_data[col_name].astype('float')
    except (TypeError, ValueError):
        print('Column name', col_name, 'could not be converted back to a float')
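A variation on the snippet above records each column's original dtype before casting, so the columns can be restored to exactly what they were (ints stay ints, floats stay floats) instead of always converting to float. This is a sketch under the same workaround assumptions; the helper names `cast_to_object` and `restore_dtypes` are hypothetical, and the fit/sample step is left as a comment.

```python
import pandas as pd


def cast_to_object(data, column_names):
    """Return a copy with column_names cast to object dtype, plus their original dtypes."""
    original_dtypes = {name: data[name].dtype for name in column_names}
    converted = data.copy()
    for name in column_names:
        converted[name] = converted[name].astype('object')
    return converted, original_dtypes


def restore_dtypes(data, original_dtypes):
    """Return a copy with each column cast back to its recorded dtype where possible."""
    restored = data.copy()
    for name, dtype in original_dtypes.items():
        try:
            restored[name] = restored[name].astype(dtype)
        except (TypeError, ValueError):
            # e.g. an int column that now contains missing values
            print(f'Column {name!r} could not be converted back to {dtype}')
    return restored


real = pd.DataFrame({'category': [100.0, 50.0, 100.0], 'value': [1, 2, 3]})
prepared, dtypes = cast_to_object(real, ['category'])
# ... fit PARSynthesizer on `prepared` and sample here ...
restored = restore_dtypes(prepared, dtypes)
print(prepared['category'].dtype, restored['category'].dtype)  # → object float64
```

Working on copies leaves the caller's original DataFrame untouched, which makes it easier to compare real and synthetic data afterwards.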