Repository: lance

lance.torch.data.LanceDataset: row IDs broken on main

Status: Open · opened by jacketsj 1 year ago · 0 comments

Requesting row IDs (`with_row_id=True`) from the torch integration currently fails on main. Here's a repro:

import lance
import pyarrow as pa
import pyarrow.compute as pc
import time
import lance.torch.data as ltd
import torch
from torch.utils.data import DataLoader

# Width of each fixed-size float32 vector in the "vector" column.
dims = 128
# Total number of rows written to the test dataset.
nrows = 10_000

def next_batch(batch_size, offset):
    """Build one record batch of ``batch_size`` synthetic rows.

    The ``id`` column holds the consecutive integers ``offset``,
    ``offset + 1``, ...; the ``vector`` column holds fixed-size lists of
    ``dims`` random values cast to float32.
    """
    flat = pc.random(dims * batch_size).cast('float32')
    ids = pa.array([offset + k for k in range(batch_size)])
    vectors = pa.FixedSizeListArray.from_arrays(flat, dims)
    table = pa.table({'id': ids, 'vector': vectors})
    return table.to_batches()[0]

def batch_iter(num_rows):
    """Yield record batches that together cover ``num_rows`` rows.

    Each batch holds at most 10,000 rows and is requested at the offset
    where the previous batch ended.
    """
    produced = 0
    while produced < num_rows:
        chunk = min(10_000, num_rows - produced)
        yield next_batch(chunk, produced)
        produced += chunk

# Derive the dataset schema from a one-row sample batch.
schema = next_batch(1, 0).schema

ds_path = "./temp-test-torch-broken.lance"
# Write nrows synthetic rows, overwriting any dataset left by a previous run,
# using the non-legacy file format.
ds = lance.write_dataset(batch_iter(nrows), ds_path, schema=schema, mode="overwrite", use_legacy_format = False)

# Read only the "vector" column and also request row ids. The reported uint64
# error presumably comes from the "_rowid" column that with_row_id=True adds.
tdataset = ltd.LanceDataset(
    ds,
    columns=["vector"],
    batch_size=1024,
    batch_readahead=8,
    with_row_id=True,
)
# NOTE(review): DataLoader's default batch_size=1 wraps each item in another
# batch dimension. If LanceDataset already yields 1024-row batches (assumed,
# not shown here), batch_size=None is the usual setting. This is unrelated
# to the reported failure; confirm the intended tensor shapes.
tdataloader = DataLoader(tdataset)
# Iterating is enough to reproduce the issue. Per the reported traceback, the
# error is raised inside lance's _to_tensor while batches are being produced.
for tbatch in tdataloader:
    tvecs = tbatch["vector"]
    trow_ids = tbatch["_rowid"]

Error output, for reference (slightly truncated):

File ~/.local/lib/python3.10/site-packages/lance/torch/data.py:64, in _to_tensor(batch, uint64_as_int64, hf_converter)
     58     del np_tensor
     59 elif (
     60     pa.types.is_integer(arr.type)
     61     or pa.types.is_floating(arr.type)
     62     or pa.types.is_boolean(arr.type)
     63 ):
---> 64     tensor = torch.from_numpy(arr.to_numpy(zero_copy_only=True))
     66     if uint64_as_int64 and tensor.dtype == torch.uint64:
     67         tensor = tensor.to(torch.int64)

TypeError: can't convert np.ndarray of type numpy.uint64. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.

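The same script runs without error on lance 0.16.1, so this is a regression. The traceback points at the likely cause. The batch contains a uint64 column (presumably `_rowid`), and `_to_tensor` calls `torch.from_numpy` on it before applying the `uint64_as_int64` conversion on the following lines. `torch.from_numpy` rejects uint64 arrays.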

Posted by jacketsj on Aug 28 '24 at 18:08.