sidechainnet

error value in batch

Open · omarkhaled-28 opened this issue 2 years ago · 1 comment

It raises an error when I try to load the training data into a batch:

```python
batch = next(iter(dataloader['train']))
print("Protein IDs\n ", batch.ids)
print("Sequences\n ", batch.seqs.shape)
print("Evolutionary Data\n ", batch.evolutionary.shape)
print("Secondary Structure\n ", batch.secondary.shape)
print("Angle Data\n ", batch.angles.shape)
print("Coordinate Data\n ", batch.coords.shape)
print("X-ray Resolution\n ", batch.resolutions)
print("Integer sequence")
print("\tShape:", batch.seqs_int.shape)
print("\tEx:", batch.seqs_int[0,:3])

print("1-hot sequence")
print("\tShape:", batch.seqs.shape)
print("\tEx:\n", batch.seqs[0,:3])
```

This is the output when I run it:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
in <cell line: 1>()
----> 1 batch = next(iter(dataloader['train']))
      2 print("Protein IDs\n ", batch.ids)
      3 print("Sequences\n ", batch.seqs.shape)
      4 print("Evolutionary Data\n ", batch.evolutionary.shape)
      5 print("Secondary Structure\n ", batch.secondary.shape)

3 frames
/usr/local/lib/python3.10/dist-packages/torch/_utils.py in reraise(self)
    692             # instantiate since we don't know how to
    693             raise RuntimeError(msg) from None
--> 694         raise exception
    695
    696

ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.10/dist-packages/sidechainnet/dataloaders/collate.py", line 85, in collate_fn
    padded_crds = pad_for_batch(coords, max_batch_len, 'crd')
  File "/usr/local/lib/python3.10/dist-packages/sidechainnet/dataloaders/collate.py", line 186, in pad_for_batch
    c = np.concatenate((item, z), axis=0)
  File "<array_function internals>", line 180, in concatenate
ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 3 dimension(s) and the array at index 1 has 2 dimension(s)
```

omarkhaled-28 · Feb 04 '24 20:02
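The last frame of the traceback is a plain NumPy shape mismatch: `np.concatenate(..., axis=0)` requires every input array to have the same number of dimensions, and here the coordinate array `item` is 3-D while the zero padding `z` is 2-D. One plausible reading (an assumption, not confirmed in this thread) is that the coordinate data and the collate code expect different coordinate layouts, for example per-residue `(L, 14, 3)` arrays versus flattened `(N, 3)` arrays. The minimal sketch below reproduces the same message with made-up shapes and shows a padding helper that copies the trailing dimensions of its input. The `(L, 14, 3)` layout and the `pad_rows` helper are hypothetical and are not sidechainnet's actual `pad_for_batch`.

```python
import numpy as np

# Hypothetical shapes for illustration only: one protein of length 5 whose
# coordinates are stored per residue as (L, 14, 3), padded to length 8.
item = np.ones((5, 14, 3))
max_len = 8

# A 2-D zero pad, as if the padding code expected flattened (N, 3) coordinates.
z_2d = np.zeros((max_len - item.shape[0], 3))
try:
    np.concatenate((item, z_2d), axis=0)
except ValueError as err:
    # Same kind of error as in the traceback: 3 dimensions vs. 2 dimensions.
    print("ValueError:", err)


def pad_rows(arr, target_len, value=0.0):
    """Hypothetical helper: pad `arr` along axis 0 to `target_len` rows.

    The pad array reuses every trailing dimension of `arr`, so it always has
    the same number of dimensions as the input, whatever that number is.
    """
    n_missing = target_len - arr.shape[0]
    if n_missing < 0:
        raise ValueError("array is longer than target_len")
    pad = np.full((n_missing,) + arr.shape[1:], value, dtype=arr.dtype)
    return np.concatenate((arr, pad), axis=0)


print(pad_rows(item, max_len).shape)             # (8, 14, 3)
print(pad_rows(np.ones((5, 3)), max_len).shape)  # (8, 3)
```

Building the pad from `arr.shape[1:]` avoids hard-coding the number of coordinate dimensions, so the same helper works for both layouts; it does not, however, tell you which layout your installed sidechainnet version expects, which is why the data-loading code is the context worth sharing.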

Can you provide more context for your code? The following works for me:

```python
import sidechainnet as scn
dataloaders = scn.load(casp_version=12, casp_thinning=30, with_pytorch="dataloaders")
batch = next(iter(dataloaders['train']))
```

dkoes · Feb 09 '24 21:02