CIFAR dataloader does not work with JAX submissions
The CIFAR dataloader no longer works with JAX algorithms that use `jax.jit`. I have not tested whether PyTorch algorithms still work with CIFAR.
Description
When running the `jax_nadamw_full_budget.py` submission on the `cifar` workload, training fails in `data_selection` when it pulls a batch from the input queue, with `ValueError: len(shards) = 128 must equal len(devices) = 8`. Here is the relevant log:
```
I0918 18:15:28.313870 140360727609984 submission_runner.py:359] Starting training loop.
Traceback (most recent call last):
  File "/algorithmic-efficiency/submission_runner.py", line 869, in <module>
    app.run(main)
  File "/usr/local/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/algorithmic-efficiency/submission_runner.py", line 834, in main
    score = score_submission_on_workload(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/algorithmic-efficiency/submission_runner.py", line 747, in score_submission_on_workload
    score, _ = train_once(
               ^^^^^^^^^^^
  File "/algorithmic-efficiency/submission_runner.py", line 375, in train_once
    batch = data_selection(
            ^^^^^^^^^^^^^^^
  File "/algorithmic-efficiency/algorithms/baselines/self_tuning/jax_nadamw_full_budget.py", line 446, in data_selection
    batch = next(input_queue)
            ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/flax/jax_utils.py", line 147, in prefetch_to_device
    enqueue(size) # Fill up the buffer.
    ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/flax/jax_utils.py", line 145, in enqueue
    queue.append(jax.tree_util.tree_map(_prefetch, data))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/jax/_src/tree_util.py", line 361, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/jax/_src/tree_util.py", line 361, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/usr/local/lib/python3.11/site-packages/flax/jax_utils.py", line 141, in _prefetch
    return jax.device_put_sharded(list(xs), devices)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/jax/_src/api.py", line 2636, in device_put_sharded
    raise ValueError(f"len(shards) = {len(shards)} must equal "
ValueError: len(shards) = 128 must equal len(devices) = 8.
2025-09-18 18:15:29.477815: W tensorflow/core/kernels/data/cache_dataset_ops.cc:916] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
```
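For context on the error: in the traceback, `flax.jax_utils.prefetch_to_device` places each array with `jax.device_put_sharded(list(xs), devices)`. Calling `list()` on an array splits it along its first axis, so the prefetcher expects every leaf to already have shape `(num_devices, per_device_batch, ...)`. A CIFAR batch that arrives as `(128, 32, 32, 3)` therefore becomes 128 shards for 8 devices. The NumPy-only sketch below reproduces that count and shows the pmap-style layout the prefetcher would accept. The two constants come from this report; the batch keys and the helpers `num_shards` and `shard_for_pmap` are hypothetical. It only illustrates the shape contract and is not necessarily the right fix for a `jax.jit` pipeline.

```python
import numpy as np

NUM_DEVICES = 8          # len(devices) in the error message
GLOBAL_BATCH_SIZE = 128  # batch size returned by get_batch_size for 'cifar'

# A global CIFAR-style batch; the 'inputs'/'targets' keys are illustrative.
batch = {
    'inputs': np.zeros((GLOBAL_BATCH_SIZE, 32, 32, 3), dtype=np.float32),
    'targets': np.zeros((GLOBAL_BATCH_SIZE,), dtype=np.int32),
}


def num_shards(x):
  # prefetch_to_device calls device_put_sharded(list(x), devices);
  # list() splits along axis 0, so the shard count equals x.shape[0].
  return len(list(x))


def shard_for_pmap(x, num_devices):
  # Hypothetical helper: add a leading per-device axis,
  # (global_batch, ...) -> (num_devices, global_batch // num_devices, ...).
  if x.shape[0] % num_devices:
    raise ValueError(f'batch size {x.shape[0]} not divisible by {num_devices}')
  return x.reshape((num_devices, -1) + x.shape[1:])


print(num_shards(batch['inputs']))    # 128 shards but only 8 devices
sharded = {k: shard_for_pmap(v, NUM_DEVICES) for k, v in batch.items()}
print(sharded['inputs'].shape)        # (8, 16, 32, 32, 3)
print(num_shards(sharded['inputs']))  # 8, one shard per device
```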
Steps to Reproduce
- In `algorithms/baselines/self_tuning/jax_nadamw_full_budget.py`, add the following two lines to the `get_batch_size` function:
  ```python
  elif workload_name == 'cifar':
    return 128
  ```
- Then run the `cifar` workload in Docker:
  ```bash
  python submission_runner.py \
    --framework=jax \
    --workload=cifar \
    --experiment_dir=/experiment_runs \
    --experiment_name=jax_debug_cifar \
    --data_dir=/data \
    --tuning_ruleset=self \
    --submission_path=algorithms/baselines/self_tuning/jax_nadamw_full_budget.py
  ```
Source or Possible Fix
I believe CIFAR is not an officially supported workload, but it is useful for debugging, so we should fix this as long as it is not too much trouble.
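One possible direction, sketched here without having been tested against the repository: if the `jax.jit`-based submissions expect one unsplit global batch (which the 128-row leading axis suggests the CIFAR pipeline now produces), the JAX CIFAR input pipeline could stop routing its batches through `prefetch_to_device` and use a prefetcher that leaves the leading axis alone. The generator below mirrors the deque-based buffering of `prefetch_to_device` but hands each whole batch to a caller-supplied `place` function. `prefetch_global_batches` and `place` are hypothetical names, and passing something like `lambda b: jax.device_put(b, sharding)` with a `NamedSharding` as `place` is an assumption about how the workload would be wired up.

```python
import collections
import itertools
from typing import Any, Callable, Iterable, Iterator


def prefetch_global_batches(
    batches: Iterable[Any],
    size: int,
    place: Callable[[Any], Any] = lambda batch: batch,
) -> Iterator[Any]:
  """Yields batches, keeping up to `size` of them placed ahead of time.

  Hypothetical replacement for flax.jax_utils.prefetch_to_device in a
  jit-based pipeline: each batch is passed to `place` as a whole, so its
  leading (global batch) axis is never split into per-device shards.
  """
  queue = collections.deque()
  iterator = iter(batches)

  def enqueue(n):
    # Place up to `n` more batches from the source iterator.
    for batch in itertools.islice(iterator, n):
      queue.append(place(batch))

  enqueue(size)  # Fill the buffer.
  while queue:
    yield queue.popleft()
    enqueue(1)  # Keep the buffer topped up.
```

With the identity default, batches pass through unchanged and in order; a real pipeline would supply a `place` that moves each batch onto the accelerators. Keeping placement as a parameter leaves the choice between this and the pmap-style reshape to whoever fixes the workload.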