Fix batch_size GPU error in map_groups
## Why are these changes needed?
map_groups passes batch_size=None to map_batches, which raises an error when GPUs are requested. batch_size should be None in map_groups because it groups by key and each group must be contained entirely within one batch. Removing the exception fixes map_groups on GPU; it should then be down to the developer to set batch_size appropriately.
## Related issue number
## Checks

- [ ] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- [ ] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
## Testing Strategy
- [ ] Unit tests
- [ ] Release tests
- [ ] This PR is not tested :(
I've removed test_strict_require_batch_size_for_gpu. Alternatively, you could distinguish calls originating from map_groups vs direct map_batches calls and raise only for the latter, but I think the developer should have the agency to change batch_size from DEFAULT_BATCH_SIZE.
@jjyao may I please request a review?
Hi @Joseph-Sarsfield, thanks for your contribution. But I don't really understand the problem. Could you give a concrete example? Also, it'd be nice if this change only affects groupby, not general map_batches.
Hi @raulchen, thanks for your response.
If I call map_groups, it currently returns a call to map_batches with batch_size set to None. Within map_batches, _apply_batch_size is called with batch_size=None, which throws a ValueError when GPUs are requested:
```python
return sorted_ds.map_batches(
    group_fn,
    batch_size=None,
    compute=compute,
    batch_format=batch_format,
    fn_args=fn_args,
    fn_kwargs=fn_kwargs,
    **ray_remote_args,
)
```
```python
def _apply_batch_size(
    given_batch_size: Optional[Union[int, Literal["default"]]], use_gpu: bool
) -> Optional[int]:
    if use_gpu and (not given_batch_size or given_batch_size == "default"):
        raise ValueError(
            "batch_size must be provided to map_batches when requesting GPUs. "
            "The optimal batch size depends on the model, data, and GPU used. "
            "It is recommended to use the largest batch size that doesn't result "
            "in your GPU device running out of memory. You can view the GPU memory "
            "usage via the Ray dashboard."
        )
    elif given_batch_size == "default":
        return ray.data.context.DEFAULT_BATCH_SIZE
    else:
        return given_batch_size
```
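To see the failure mode in isolation, the check can be replicated standalone. This is a sketch: DEFAULT_BATCH_SIZE is stubbed here rather than imported from ray.data.context, and the error message is abbreviated.

```python
from typing import Literal, Optional, Union

# Stand-in for ray.data.context.DEFAULT_BATCH_SIZE; the real value lives in Ray.
DEFAULT_BATCH_SIZE = 1024


def apply_batch_size(
    given_batch_size: Optional[Union[int, Literal["default"]]], use_gpu: bool
) -> Optional[int]:
    # Mirrors _apply_batch_size: an unset batch size plus a GPU request is rejected.
    if use_gpu and (not given_batch_size or given_batch_size == "default"):
        raise ValueError(
            "batch_size must be provided to map_batches when requesting GPUs."
        )
    elif given_batch_size == "default":
        return DEFAULT_BATCH_SIZE
    return given_batch_size


# map_groups always forwards batch_size=None, so any GPU request fails:
try:
    apply_batch_size(None, use_gpu=True)
except ValueError as exc:
    print("raises:", exc)

print(apply_batch_size("default", use_gpu=False))  # -> 1024
print(apply_batch_size(32, use_gpu=True))          # -> 32
```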
Thanks for the update. The reason for adding this limitation is that we want to let users explicitly set the batch_size for GPUs, to avoid unexpected GPU RAM OOMs, so we want to keep it. One solution is to allow setting batch_size in map_groups. Or maybe you can also work around this in your code by doing GPU inference before the groupby?
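The first suggestion (letting callers set batch_size on map_groups and forwarding it to map_batches) might look like the sketch below. The mock classes and forwarding logic are illustrative only, not Ray's actual implementation.

```python
class MockSortedDataset:
    """Stand-in for the sorted dataset; records the forwarded batch_size."""

    def map_batches(self, fn, batch_size="default", **kwargs):
        return {"fn": fn, "batch_size": batch_size, **kwargs}


class MockGroupedData:
    """Sketch of a GroupedData whose map_groups accepts batch_size (hypothetical)."""

    def __init__(self):
        self._sorted_ds = MockSortedDataset()

    def map_groups(self, group_fn, batch_size=None, **ray_remote_args):
        # Forward a user-supplied batch_size instead of hard-coding None.
        # Caveat: any value other than None may split a single group across
        # batches, which map_groups' semantics rely on avoiding.
        return self._sorted_ds.map_batches(
            group_fn, batch_size=batch_size, **ray_remote_args
        )


result = MockGroupedData().map_groups(lambda batch: batch, batch_size=64)
print(result["batch_size"])  # -> 64
```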
Agree with @raulchen . This is a good fix, but we should make it apply only for map_groups() and not all map operations.
batch_size is None in map_groups because each batch is the entire block, so I think this needs to stay None. Can we pass a flag from map_groups to bypass the exception?
For reference, the comments in map_groups read:

> The batch is the entire block, because we have batch_size=None for map_batches() below. Note we set batch_size=None here, so it will use the entire block as a batch, which ensures that each group will be contained within a batch in entirety.
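The flag-based bypass could be sketched as follows. The parameter name from_map_groups is hypothetical, and DEFAULT_BATCH_SIZE is stubbed in place of ray.data.context.DEFAULT_BATCH_SIZE.

```python
from typing import Literal, Optional, Union

DEFAULT_BATCH_SIZE = 1024  # stand-in for ray.data.context.DEFAULT_BATCH_SIZE


def apply_batch_size(
    given_batch_size: Optional[Union[int, Literal["default"]]],
    use_gpu: bool,
    from_map_groups: bool = False,  # hypothetical flag set by map_groups
) -> Optional[int]:
    # map_groups legitimately needs batch_size=None (one group per batch),
    # so the GPU check is skipped for calls originating there.
    if from_map_groups:
        return None
    if use_gpu and (not given_batch_size or given_batch_size == "default"):
        raise ValueError(
            "batch_size must be provided to map_batches when requesting GPUs."
        )
    elif given_batch_size == "default":
        return DEFAULT_BATCH_SIZE
    return given_batch_size


print(apply_batch_size(None, use_gpu=True, from_map_groups=True))  # -> None
print(apply_batch_size(32, use_gpu=True))                          # -> 32
```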
Hi @raulchen @ericl, I've updated the PR to pass a flag that identifies calls originating from map_groups.
Thanks for catching this @Joseph-Sarsfield; we actually fixed this a la #45305