Loading `.tfrecords` files that require a deserialization method
🐛 Describe the bug
Hi,
I have a dataset in TFRecord format and am trying to switch to TorchData's API for loading `.tfrecords` files. Here is a minimal example:
from torchdata.datapipes.iter import FileOpener, IterableWrapper

datapipe1 = IterableWrapper(['path/to/my/tfrecords/file.tfrecords'])
datapipe2 = FileOpener(datapipe1, mode="b")  # open the file in binary mode
tfrecord_loader_dp = datapipe2.load_from_tfrecord()
for d in tfrecord_loader_dp:
    pass
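It fails, apparently because the datapipe does not know how to deserialize the TFRecord file. Note that the error is raised inside `iterate_tfrecord_file`, while reading a record's length, before any example is parsed: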
File ~/.conda/envs/bend/lib/python3.10/site-packages/torchdata/datapipes/iter/util/tfrecordloader.py:245, in TFRecordLoaderIterDataPipe.__iter__(self)
    243 pathname, data_stream = data
    244 try:
--> 245     for example_bytes in iterate_tfrecord_file(data_stream):
    246         example = example_pb2.SequenceExample()  # type: ignore
    247         example.ParseFromString(example_bytes)  # type: ignore

File ~/.conda/envs/bend/lib/python3.10/site-packages/torchdata/datapipes/iter/util/tfrecordloader.py:83, in iterate_tfrecord_file(data)
     81 (length,) = struct.unpack("<Q", length_bytes)
     82 if length > len(data_bytes):
---> 83     data_bytes = data_bytes.zfill(int(length * 1.5))
     84 data_bytes_view = memoryview(data_bytes)[:length]
     85 if data.readinto(data_bytes_view) != length:

OverflowError: Python int too large to convert to C ssize_t
This exception is thrown by __iter__ of TFRecordLoaderIterDataPipe(datapipe=FileOpenerIterDataPipe, length=-1, spec=None)
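The frames above show where the error comes from: `iterate_tfrecord_file` unpacks an 8-byte little-endian length and then sizes a buffer with `int(length * 1.5)`. A minimal sketch, assuming (not confirmed) that the file is GZIP-compressed, which the `compression_type` argument in the TensorFlow snippet below hints at: if the loader reads a gzip header as a length prefix, the resulting value is so large that `int(length * 1.5)` no longer fits in a C `ssize_t`. The helper `read_length_prefix`, the payload, and the `mtime` value are hypothetical and only for illustration.

```python
import gzip
import io
import struct
import sys


def read_length_prefix(stream: io.BufferedIOBase) -> int:
    """Return the uint64 little-endian length that starts every TFRecord record."""
    (length,) = struct.unpack("<Q", stream.read(8))
    return length


# Hypothetical payload; real files hold serialized protocol buffers.
payload = b"hello"
# Record layout: uint64 length, uint32 CRC of the length, payload, uint32 CRC of the payload.
# The CRC fields are zero-filled because this sketch only reads the length prefix.
record = struct.pack("<Q", len(payload)) + b"\x00" * 4 + payload + b"\x00" * 4

plain_length = read_length_prefix(io.BytesIO(record))
print(plain_length)  # 5

# After GZIP compression the first 8 bytes are the gzip header
# (magic 1f 8b, method 08, flags 00, 4-byte mtime), read here as a "length".
compressed = gzip.compress(record, mtime=1690000000)
bogus_length = read_length_prefix(io.BytesIO(compressed))
print(bogus_length)
# The traceback shows the buffer being sized with int(length * 1.5); that no longer fits in ssize_t.
print(int(bogus_length * 1.5) > sys.maxsize)  # True
```

If the real file starts with the gzip magic bytes `1f 8b`, this would explain the `OverflowError`; a corrupt header in an uncompressed file could produce the same symptom, so this is only one candidate explanation.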
In the legacy TensorFlow codebase, I have to specify a function that deserializes the records:
import json

import tensorflow as tf
import tensorflow_datasets as tfds

# compression_type and json_file are defined elsewhere in my code (not shown here)
dataset = tf.data.Dataset.from_tensor_slices(['path/to/my/tfrecords/file.tfrecords'])
dataset = dataset.interleave(
    lambda fp: tf.data.TFRecordDataset(fp, compression_type=compression_type),
    cycle_length=1, block_length=1, num_parallel_calls=tf.data.AUTOTUNE)
features = tfds.features.FeaturesDict.from_json(json.load(json_file))  # json_file describes the features stored in the .tfrecords file
dataset = dataset.map(features.deserialize_example, num_parallel_calls=tf.data.AUTOTUNE)
iterator = dataset.as_numpy_iterator()
for d in iterator:
    pass  # this works; each d is a dict of NumPy arrays
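The TensorFlow pipeline passes `compression_type` to `tf.data.TFRecordDataset`, while the `load_from_tfrecord` call above takes no compression setting (its repr lists only `datapipe`, `length` and `spec`). The heuristic sketch below checks the first bytes of a file for a gzip or zlib header. `guess_tfrecord_compression` is a hypothetical helper, and an uncompressed TFRecord whose first length bytes happen to look like a zlib header would be misreported.

```python
import gzip
import zlib


def guess_tfrecord_compression(head: bytes) -> str:
    """Guess a TFRecord file's compression from its first two bytes (heuristic).

    Returns "GZIP", "ZLIB" or "" (no compression detected), matching the strings
    tf.data.TFRecordDataset accepts for compression_type.
    """
    if head[:2] == b"\x1f\x8b":  # gzip magic number (RFC 1952)
        return "GZIP"
    if len(head) >= 2:
        cmf, flg = head[0], head[1]
        # zlib header (RFC 1950): method 8 (deflate), window <= 32 KiB, and
        # the 16-bit value CMF * 256 + FLG must be a multiple of 31.
        if cmf & 0x0F == 8 and cmf >> 4 <= 7 and (cmf * 256 + flg) % 31 == 0:
            return "ZLIB"
    return ""


# Quick self-check on in-memory data; for a real file, pass open(path, "rb").read(2).
print(guess_tfrecord_compression(gzip.compress(b"record")))  # GZIP
print(guess_tfrecord_compression(zlib.compress(b"record")))  # ZLIB
print(repr(guess_tfrecord_compression(b"\x05" + b"\x00" * 7)))  # ''
```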
The problem is that the records need this deserialization step, but TFRecordLoaderIterDataPipe fails before I can apply anything to its output.
Is there a workaround? I tried wrapping the TensorFlow dataset object in an IterableWrapper, but the dataset cannot be pickled, so it fails in DataLoader2.
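On the pickling problem: the dataset object itself cannot be pickled, but a module-level function that builds it can. A minimal sketch of a picklable wrapper that stores only a factory function and its arguments and creates the real iterable inside `__iter__`. `LazyIterable` and `squares` are hypothetical names; `squares` returns a generator, which like the dataset cannot be pickled, and merely stands in for a function that would build the `tf.data` pipeline for a given path.

```python
import pickle
from typing import Any, Callable, Iterable, Iterator, Tuple


class LazyIterable:
    """Picklable iterable that defers creating an unpicklable one until iteration.

    Only `factory` and its arguments are pickled. `factory` must be a module-level
    function so pickle can reference it by name; the iterable it returns is built
    fresh in whichever process calls __iter__.
    """

    def __init__(self, factory: Callable[..., Iterable[Any]], *args: Any) -> None:
        self.factory = factory
        self.args: Tuple[Any, ...] = args

    def __iter__(self) -> Iterator[Any]:
        return iter(self.factory(*self.args))


def squares(n: int) -> Iterator[int]:
    """Stand-in factory returning a generator, which cannot be pickled directly."""
    for i in range(n):
        yield i * i


try:
    pickle.dumps(squares(4))
except TypeError as exc:
    print("direct:", exc)  # the generator object itself is not picklable

restored = pickle.loads(pickle.dumps(LazyIterable(squares, 4)))
print(list(restored))  # [0, 1, 4, 9]
```

In the issue's setting, the factory would take a path, build the `interleave`/`map` pipeline from the TensorFlow snippet, and return `dataset.as_numpy_iterator()`; the wrapper could then be passed to `IterableWrapper`. This is untested: each worker process would import and initialize TensorFlow on its own, and without a sharding step every worker may see every record.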
Thanks!
Versions
Collecting environment information...
PyTorch version: 2.0.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.27.4
Libc version: glibc-2.31

Python version: 3.10.12 (main, Jul 5 2023, 18:54:27) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-1027-aws-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 16
On-line CPU(s) list: 0-15
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
Stepping: 7
CPU MHz: 2499.994
BogoMIPS: 4999.98
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 256 KiB
L1i cache: 256 KiB
L2 cache: 8 MiB
L3 cache: 35.8 MiB
NUMA node0 CPU(s): 0-15
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] torch==2.0.1
[pip3] torchdata==0.6.1
[pip3] torchvision==0.15.2
[pip3] triton==2.0.0
[conda] numpy 1.24.3 pypi_0 pypi
[conda] torch 2.0.1 pypi_0 pypi
[conda] torchdata 0.6.1 pypi_0 pypi
[conda] torchvision 0.15.2 pypi_0 pypi
[conda] triton 2.0.0 pypi_0 pypi
I think DataLoader2 and DataPipes are going to be removed in the future, but in the meantime, are there any workarounds?