
ParquetDataset class missing 'labels' argument

Open OscarBarreraGithub opened this issue 1 year ago • 1 comments

In the example script 02_train_tito_model.py, labels are passed to the dataloader:

'labels': { "direction": Direction( azimuth_key="injection_azimuth", zenith_key="injection_zenith" ) }

The SQLiteDataset handles these labels fine, but when loading with ParquetDataset, we get:

TypeError: ParquetDataset.__init__() got an unexpected keyword argument 'labels'

This occurs when using the GraphNetDataModule, although I believe the error lies in the ParquetDataset class itself: its __init__ does not accept labels as an argument, even though the Dataset class documents it in its own __init__ as "labels: Dictionary of labels to be added to the dataset".
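
The failure mode can be reproduced without graphnet installed. The following is a minimal sketch using simplified stand-in classes; `BaseDataset`, `ParquetLikeDataset`, and `try_build` are hypothetical names for illustration, not the actual graphnet classes or their real signatures.

```python
from typing import Any, Dict, Optional


class BaseDataset:
    """Hypothetical stand-in for a base class that accepts `labels`."""

    def __init__(self, path: str, labels: Optional[Dict[str, Any]] = None) -> None:
        self.path = path
        self.labels = labels or {}


class ParquetLikeDataset(BaseDataset):
    """Hypothetical subclass whose __init__ does not list `labels`."""

    def __init__(self, path: str) -> None:
        super().__init__(path)


def try_build(**kwargs: Any) -> str:
    """Build the subclass and return "ok", or the TypeError message on failure."""
    try:
        ParquetLikeDataset(**kwargs)
        return "ok"
    except TypeError as err:
        return str(err)


# Passing `labels` fails because the subclass signature does not accept it,
# even though the base class does.
message = try_build(path="data.parquet", labels={"direction": object()})
print(message)
```

The keyword is rejected at the subclass boundary, so the base class never sees it, which matches the reported behaviour.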

OscarBarreraGithub avatar Jul 03 '24 21:07 OscarBarreraGithub

@OscarBarreraGithub thanks for sharing this.

It looks like it's just a simple oversight in the list of arguments for ParquetDataset: it's missing the entry labels: Optional[Dict[str, Any]] = None, which should also be passed to super() here.

Fixing this would require just two lines of code:

  1. Add the argument to ParquetDataset: labels: Optional[Dict[str, Any]] = None
  2. Propagate the argument to super() here
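
The two steps above can be sketched as follows. This is an illustration with simplified stand-in classes (`BaseDataset` and `ParquetLikeDataset` are hypothetical names), assuming the base class already accepts `labels`; the real ParquetDataset takes more arguments than shown here.

```python
from typing import Any, Dict, Optional


class BaseDataset:
    """Hypothetical stand-in for a base class that already accepts `labels`."""

    def __init__(self, path: str, labels: Optional[Dict[str, Any]] = None) -> None:
        self.path = path
        self.labels = labels or {}


class ParquetLikeDataset(BaseDataset):
    """Hypothetical subclass with both changes applied."""

    def __init__(
        self,
        path: str,
        labels: Optional[Dict[str, Any]] = None,  # step 1: add the argument
    ) -> None:
        super().__init__(path, labels=labels)  # step 2: propagate to super()


# The placeholder string stands in for a real label object.
dataset = ParquetLikeDataset("data.parquet", labels={"direction": "placeholder"})
print(dataset.labels)
```

Defaulting `labels` to None keeps existing calls that omit the argument working unchanged.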

RasmusOrsoe avatar Jul 04 '24 09:07 RasmusOrsoe

closed by #730

RasmusOrsoe avatar Aug 13 '24 12:08 RasmusOrsoe