Image cast_storage very slow for arrays (e.g. numpy, tensors)
Update: see comments below
### Describe the bug
Operations that save an image from a path are very slow.
I believe the reason is that the image data (numpy) is converted into pyarrow format and then back to Python via `.to_pylist()` before being converted to a numpy array again.
`to_pylist()` is slow in general, but used on a multi-dimensional numpy array such as an image it takes a very long time.
From the trace below we can see that `__arrow_array__` takes a long time.
It is currently also called in `get_inferred_type`; that call should be removable (#6781), but removing it doesn't change the underlying issue.
The conversion to pyarrow and back also leaves the numpy array with dtype int64, which causes a warning message because the Image type expects uint8.
However, the original numpy image array was uint8.
### Steps to reproduce the bug
```python
from PIL import Image
import numpy as np
import datasets
import cProfile

image = Image.fromarray(np.random.randint(0, 255, (2048, 2048, 3), dtype=np.uint8))
image.save("test_image.jpg")

ds = datasets.Dataset.from_dict(
    {"image": ["test_image.jpg"]},
    features=datasets.Features({"image": datasets.Image(decode=True)}),
)
# load as numpy array, e.g. for further processing with map
# same result as map returning numpy arrays
ds.set_format("numpy")
cProfile.run("ds.map(writer_batch_size=1, load_from_cache_file=False)", "restats")
```
```
Fri Apr  5 14:56:17 2024    restats

66817 function calls (64992 primitive calls) in 33.382 seconds

Ordered by: cumulative time
List reduced from 1073 to 20 due to restriction <20>

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  46/1    0.000    0.000   33.382   33.382 {built-in method builtins.exec}
     1    0.000    0.000   33.382   33.382 <string>:1(<module>)
     1    0.000    0.000   33.382   33.382 arrow_dataset.py:594(wrapper)
     1    0.000    0.000   33.382   33.382 arrow_dataset.py:551(wrapper)
     1    0.000    0.000   33.379   33.379 arrow_dataset.py:2916(map)
     4    0.000    0.000   33.327    8.332 arrow_dataset.py:3277(_map_single)
     1    0.000    0.000   33.311   33.311 arrow_writer.py:465(write)
     2    0.000    0.000   33.311   16.656 arrow_writer.py:423(write_examples_on_file)
     1    0.000    0.000   33.311   33.311 arrow_writer.py:527(write_batch)
     2   14.484    7.242   33.260   16.630 arrow_writer.py:161(__arrow_array__)
     1    0.001    0.001   16.438   16.438 arrow_writer.py:121(get_inferred_type)
     1    0.000    0.000   14.398   14.398 threading.py:637(wait)
     1    0.000    0.000   14.398   14.398 threading.py:323(wait)
     8   14.398    1.800   14.398    1.800 {method 'acquire' of '_thread.lock' objects}
   4/2    0.000    0.000    4.337    2.169 table.py:1800(wrapper)
     2    0.000    0.000    4.337    2.169 table.py:1950(cast_array_to_feature)
     2    0.475    0.238    4.337    2.169 image.py:209(cast_storage)
     9    2.583    0.287    2.583    0.287 {built-in method numpy.array}
     2    0.000    0.000    1.284    0.642 image.py:319(encode_np_array)
     2    0.000    0.000    1.246    0.623 image.py:301(image_to_bytes)
```
### Expected behavior
The numpy image data should be passed through as-is, since it is directly consumed by Pillow to convert it to bytes.
As a test, one can replace `list_of_np_array_to_pyarrow_listarray(data)` in `__arrow_array__` with just `out = data`.
We then have to change `cast_storage` of the `Image` feature so it handles the passed-through data (and check the type beforehand):

```python
bytes_array = pa.array(
    [encode_np_array(arr)["bytes"] if arr is not None else None for arr in storage],
    type=pa.binary(),
)
```
Leading to the following:
```
Fri Apr  5 15:44:27 2024    restats

66419 function calls (64595 primitive calls) in 0.937 seconds

Ordered by: cumulative time
List reduced from 1023 to 20 due to restriction <20>

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  47/1    0.000    0.000    0.935    0.935 {built-in method builtins.exec}
   2/1    0.000    0.000    0.935    0.935 <string>:1(<module>)
   2/1    0.000    0.000    0.934    0.934 arrow_dataset.py:594(wrapper)
   2/1    0.000    0.000    0.934    0.934 arrow_dataset.py:551(wrapper)
   2/1    0.000    0.000    0.934    0.934 arrow_dataset.py:2916(map)
     4    0.000    0.000    0.933    0.233 arrow_dataset.py:3277(_map_single)
     1    0.000    0.000    0.883    0.883 arrow_writer.py:466(write)
     2    0.000    0.000    0.883    0.441 arrow_writer.py:424(write_examples_on_file)
     1    0.000    0.000    0.882    0.882 arrow_writer.py:528(write_batch)
     2    0.000    0.000    0.877    0.439 arrow_writer.py:161(__arrow_array__)
   4/2    0.000    0.000    0.877    0.439 table.py:1800(wrapper)
     2    0.000    0.000    0.877    0.439 table.py:1950(cast_array_to_feature)
     2    0.009    0.005    0.877    0.439 image.py:209(cast_storage)
     2    0.000    0.000    0.868    0.434 image.py:335(encode_np_array)
     2    0.000    0.000    0.856    0.428 image.py:317(image_to_bytes)
     2    0.000    0.000    0.822    0.411 Image.py:2376(save)
     2    0.000    0.000    0.822    0.411 PngImagePlugin.py:1233(_save)
     2    0.000    0.000    0.822    0.411 ImageFile.py:517(_save)
     2    0.000    0.000    0.821    0.411 ImageFile.py:545(_encode_tile)
   589    0.803    0.001    0.803    0.001 {method 'encode' of 'ImagingEncoder' objects}
```
This is of course only a test, as it passes through all numpy arrays irrespective of whether they should be an image.
Also, I guess `cast_storage` is meant for casting pyarrow storage exclusively.
Converting to a pyarrow array seems like a good solution since it also handles pytorch tensors etc.; maybe there is a more efficient way to create a PIL image from a pyarrow array?
Not sure how this should be handled, but I would be happy to help if there is a good solution.
### Environment info
- `datasets` version: 2.18.1.dev0
- Platform: Linux-6.7.11-200.fc39.x86_64-x86_64-with-glibc2.38
- Python version: 3.12.2
- `huggingface_hub` version: 0.22.2
- PyArrow version: 15.0.2
- Pandas version: 2.2.1
- `fsspec` version: 2024.3.1
This may be a solution that only changes `cast_storage` of `Image`.
However, I'm not totally sure that the assumptions made about the `ListArray` hold:
```python
elif pa.types.is_list(storage.type):
    from .features import Array3DExtensionType

    def get_shapes(arr):
        # infer the (height, width, channels) shape by repeated flattening;
        # assumes all sub-lists at a given level have equal length
        shape = ()
        while isinstance(arr, pa.ListArray):
            len_curr = len(arr)
            arr = arr.flatten()
            len_new = len(arr)
            shape = shape + (len_new // len_curr,)
        return shape

    def get_dtypes(arr):
        # unwrap the nested list type down to the primitive value type
        dtype = arr.type
        while hasattr(dtype, "value_type"):
            dtype = dtype.value_type
        return dtype

    arrays = []
    for i, is_null in enumerate(storage.is_null()):
        if not is_null.as_py():
            storage_part = storage.take([i])
            shape = get_shapes(storage_part)
            dtype = get_dtypes(storage_part)
            extension_type = Array3DExtensionType(shape=shape, dtype=str(dtype))
            array = pa.ExtensionArray.from_storage(extension_type, storage_part)
            arrays.append(array.to_numpy().squeeze(0))
        else:
            arrays.append(None)

    bytes_array = pa.array(
        [encode_np_array(arr)["bytes"] if arr is not None else None for arr in arrays],
        type=pa.binary(),
    )
    path_array = pa.array([None] * len(storage), type=pa.string())
    storage = pa.StructArray.from_arrays(
        [bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null()
    )
```
(Edited to handle nulls.)
Notably, this doesn't change anything about the passing through of data or other things; it only touches the `Image` class.
Seems quite fast:
```
Fri Apr  5 17:55:51 2024    restats

63818 function calls (61995 primitive calls) in 0.812 seconds

Ordered by: cumulative time
List reduced from 1051 to 20 due to restriction <20>

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  47/1    0.000    0.000    0.810    0.810 {built-in method builtins.exec}
   2/1    0.000    0.000    0.810    0.810 <string>:1(<module>)
   2/1    0.000    0.000    0.809    0.809 arrow_dataset.py:594(wrapper)
   2/1    0.000    0.000    0.809    0.809 arrow_dataset.py:551(wrapper)
   2/1    0.000    0.000    0.809    0.809 arrow_dataset.py:2916(map)
     3    0.000    0.000    0.807    0.269 arrow_dataset.py:3277(_map_single)
     1    0.000    0.000    0.760    0.760 arrow_writer.py:589(finalize)
     1    0.000    0.000    0.760    0.760 arrow_writer.py:423(write_examples_on_file)
     1    0.000    0.000    0.759    0.759 arrow_writer.py:527(write_batch)
     1    0.001    0.001    0.754    0.754 arrow_writer.py:161(__arrow_array__)
   2/1    0.000    0.000    0.719    0.719 table.py:1800(wrapper)
     1    0.000    0.000    0.719    0.719 table.py:1950(cast_array_to_feature)
     1    0.006    0.006    0.718    0.718 image.py:209(cast_storage)
     1    0.000    0.000    0.451    0.451 image.py:361(encode_np_array)
     1    0.000    0.000    0.444    0.444 image.py:343(image_to_bytes)
     1    0.000    0.000    0.413    0.413 Image.py:2376(save)
     1    0.000    0.000    0.413    0.413 PngImagePlugin.py:1233(_save)
     1    0.000    0.000    0.413    0.413 ImageFile.py:517(_save)
     1    0.000    0.000    0.413    0.413 ImageFile.py:545(_encode_tile)
   397    0.409    0.001    0.409    0.001 {method 'encode' of 'ImagingEncoder' objects}
```
I also encountered this problem and have been struggling with it for a long time...
This actually applies to all arrays (numpy arrays or tensors, e.g. from torch), not only those loaded from external files.

```python
import numpy as np
import datasets

ds = datasets.Dataset.from_dict(
    {"image": [np.random.randint(0, 255, (2048, 2048, 3), dtype=np.uint8)]},
    features=datasets.Features({"image": datasets.Image(decode=True)}),
)
ds.set_format("numpy")
ds = ds.map(load_from_cache_file=False)
```