TypeError during sample_conditions() due to progress bar (`tqdm`)
Environment Details
Please indicate the following details about the environment in which you found the bug:
- SDV version: sdv-0.15.0
- Python version: Python 3.8.10
- Operating System: Ubuntu 18.04, Docker image tensorflow/tensorflow:2.9.1-gpu
Error Description
I trained a CTGAN and want to generate about 450000 samples.
However, during sample_conditions() an exception occurs that seems to be related to the progress bar implementation. See below for the traceback.
Numpy version is 1.22.4.
Steps to reproduce
- Install sdv
- Train CTGAN model
- Run sample_conditions()
- Crash
The code is really straightforward. Below is a simplified version of it.
from sdv.sampling import Condition
from sdv.tabular import CTGAN
# data is the training DataFrame and save_path is the output path (placeholders)
model = CTGAN(field_transformers={"stalling": "categorical"})
model.fit(data)
model.save(save_path)
stalling_condition = Condition({"stalling": 1}, num_rows=450000)
stalling_samples = model.sample_conditions(conditions=[stalling_condition])
The relevant output:
0%| | 0/451702 [00:00<?, ?it/s]
Sampling conditions: 0%| | 0/451702 [00:00<?, ?it/s]
Sampling conditions: 0%| | 0/451702 [05:23<?, ?it/s]
Error: Sampling terminated. Partial results are stored in a temporary file: .sample.csv.temp. This file will be overridden the next time you sample. Please rename the file if you wish to save these results.
Traceback (most recent call last):
File "scripts/ctgan_rf.py", line 81, in <module>
sys.exit(main())
File "scripts/ctgan_rf.py", line 72, in main
upsample(model, ds)
File "scripts/ctgan_rf.py", line 48, in upsample
stalling_samples = model.sample_conditions(conditions=[stalling_condition])
File "/usr/local/lib/python3.8/dist-packages/sdv/tabular/base.py", line 667, in sample_conditions
return self._sample_conditions(
File "/usr/local/lib/python3.8/dist-packages/sdv/tabular/base.py", line 715, in _sample_conditions
handle_sampling_error(output_file_path == TMP_FILE_NAME, output_file_path, error)
File "/usr/local/lib/python3.8/dist-packages/sdv/tabular/utils.py", line 185, in handle_sampling_error
raise sampling_error
File "/usr/local/lib/python3.8/dist-packages/sdv/tabular/base.py", line 703, in _sample_conditions
sampled = progress_bar_wrapper(_sample_function, num_rows, 'Sampling conditions')
File "/usr/local/lib/python3.8/dist-packages/sdv/tabular/utils.py", line 157, in progress_bar_wrapper
return function(progress_bar)
File "/usr/local/lib/python3.8/dist-packages/sdv/tabular/base.py", line 689, in _sample_function
sampled_for_condition = self._sample_with_conditions(
File "/usr/local/lib/python3.8/dist-packages/sdv/tabular/base.py", line 612, in _sample_with_conditions
sampled_rows = self._conditionally_sample_rows(
File "/usr/local/lib/python3.8/dist-packages/sdv/tabular/base.py", line 386, in _conditionally_sample_rows
sampled_rows = self._sample_batch(
File "/usr/local/lib/python3.8/dist-packages/sdv/tabular/base.py", line 345, in _sample_batch
if progress_bar:
TypeError: __bool__ should return bool, returned numpy.bool_
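One plausible mechanism, not confirmed anywhere in this thread: the line `if progress_bar:` makes Python call the progress bar's `__bool__`. Assuming the bar's truthiness is derived from a comparison such as `total > 0`, a `total` that is a NumPy integer makes that comparison return a `numpy.bool_` instead of a Python `bool`, and Python rejects it with exactly this TypeError. The sketch below uses a hypothetical `FakeProgressBar` class (not SDV or tqdm code) to reproduce the message, and shows two defensive fixes: casting the total to `int`, or testing `is not None` instead of truthiness.

```python
import numpy as np


class FakeProgressBar:
    """Hypothetical stand-in for a progress bar whose truthiness depends on `total`."""

    def __init__(self, total):
        self.total = total

    def __bool__(self):
        # If total is a NumPy integer, this comparison yields numpy.bool_,
        # not a Python bool, and Python refuses that as a __bool__ result.
        return self.total > 0


def truthiness_error(bar):
    """Return the TypeError message raised by `if bar:`, or None if it succeeds."""
    try:
        if bar:
            pass
    except TypeError as exc:
        return str(exc)
    return None


# A NumPy integer total reproduces the error from the traceback.
numpy_bar = FakeProgressBar(np.int64(451702))
print(truthiness_error(numpy_bar))

# Fix 1: cast the total to a plain Python int before building the bar.
int_bar = FakeProgressBar(int(np.int64(451702)))
print(truthiness_error(int_bar))

# Fix 2: check whether a bar exists instead of asking for its truthiness.
if numpy_bar is not None:
    print("progress bar present")
```

The exact type name in the message can differ between NumPy versions, but the leading text `__bool__ should return bool` comes from Python itself.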
Hi @mphe, thanks for the details and stack trace. We'll try to replicate and update this issue as we learn more.
I'm wondering if there is some bad interaction between tqdm and tensorflow. Would it be possible to test this in an environment without tensorflow to see if the error persists? (You can try training on a subset of the data if it's taking too long.)
In the meantime, if you are only passing a single condition into the sample_conditions method, you can try the following workaround:
stalling_samples = model.sample_conditions(
conditions=[stalling_condition],
max_tries=1,
batch_size_per_try=450000,  # set this to num_rows or higher
)
Setting max_tries=1 will effectively turn off the progress bar whenever there is a single condition. Unfortunately, you may receive fewer than the requested number of rows, because invalid rows are filtered out and there are no retries.
Thanks for the workaround, that works so far. I wrote my own retry loop around it.
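The retry loop itself is not shown in the thread. A minimal sketch of the idea, assuming a hypothetical `sample_fn(n)` callable that returns a pandas DataFrame with at most `n` rows (for example, a wrapper around the `max_tries=1` workaround), keeps requesting the missing rows until enough are collected or an attempt limit is hit:

```python
import pandas as pd


def sample_with_retries(sample_fn, num_rows, max_attempts=10):
    """Collect up to num_rows rows by calling sample_fn repeatedly.

    sample_fn(n) is a hypothetical callable returning a DataFrame with
    at most n rows; it may return fewer when invalid rows are dropped.
    """
    chunks = []
    collected = 0
    for _ in range(max_attempts):
        remaining = num_rows - collected
        if remaining <= 0:
            break
        chunk = sample_fn(remaining)  # ask only for the rows still missing
        chunks.append(chunk)
        collected += len(chunk)
    if not chunks:
        return pd.DataFrame()
    # Concatenate all batches and trim any overshoot.
    return pd.concat(chunks, ignore_index=True).head(num_rows)
```

With SDV 0.15.0, `sample_fn` could plausibly wrap the workaround call, e.g. `lambda n: model.sample_conditions(conditions=[Condition({"stalling": 1}, num_rows=n)], max_tries=1, batch_size_per_try=n)`; that wrapper is an untested assumption. The `max_attempts` cap prevents an endless loop if the model rarely produces valid rows, at the cost of possibly returning fewer than `num_rows` rows.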
I just tested in a clean Python 3.9.9 environment where I only ran pip install sdv, but the error persists.
Here is the package list in case it is helpful.
certifi==2022.6.15
charset-normalizer==2.1.0
copulas==0.7.0
ctgan==0.5.1
cycler==0.11.0
deepecho==0.3.0.post1
Faker==9.9.1
fonttools==4.34.0
graphviz==0.20
idna==3.3
joblib==1.1.0
kiwisolver==1.4.3
llvmlite==0.38.1
matplotlib==3.5.2
numba==0.55.2
numpy==1.22.4
packaging==21.3
pandas==1.4.3
Pillow==9.2.0
psutil==5.9.1
pyparsing==3.0.9
python-dateutil==2.8.2
pyts==0.12.0
pytz==2022.1
PyYAML==5.4.1
rdt==0.6.4
requests==2.28.1
scikit-learn==1.1.1
scipy==1.7.3
sdmetrics==0.5.0
sdv==0.15.0
six==1.16.0
text-unidecode==1.3
threadpoolctl==3.1.0
torch==1.12.0
torchvision==0.13.0
tqdm==4.64.0
typing_extensions==4.3.0
urllib3==1.26.9
Hello @mphe, we've just released SDV 0.16.0.
We were unable to replicate your exact bug, but in this version we've significantly cleaned up the sampling and we think it may have fixed what you encountered. Could you try again to see if it still persists?
Hello, I am having the same problem. I am using SDV 0.16.0 and trying to run this within a loop in parallel using multiprocessing. I have tried your suggestion, but I am still getting the error message.
My conditions are as follows:
condition = Condition({"Outcome": unique_outcomes[x]})
new_data3 = model.sample_conditions(conditions=[condition], max_tries_per_batch=1, batch_size=fake_data_size)
The error is as follows:
Sampling terminated. Partial results are stored in a temporary file: .sample.csv.temp. This file will be overridden the next time you sample. Please rename the file if you wish to save these results.
Then:
[Errno 2] The system cannot find the file specified: '.sample.csv.temp'
Any ideas what I can do? Thanks
Hi @ASJB, are you seeing the same error as the initial issue?
TypeError: __bool__ should return bool, returned numpy.bool_
If you are observing a different error, would you be able to file a new issue for it? The root cause may be different.
I'll update this issue's title to be more specific about the cause.
Hi @mphe, I'm closing this issue as it has been inactive for some time. If you're still observing this on the latest version of the SDV (0.16.0), please respond and I can reopen the issue to investigate it.
@ASJB, please feel free to submit a new issue with your error. We can investigate that separately.
Sorry for the late reply. I can confirm that the bug no longer appears. Thanks for fixing!