ParquetDataset class missing 'labels' argument
In the example script 02_train_tito_model.py, the following labels are passed to the dataloader:
'labels': { "direction": Direction( azimuth_key="injection_azimuth", zenith_key="injection_zenith" ) }
The SQLiteDataset handles these labels fine, but when loading with ParquetDataset we get:
TypeError: ParquetDataset.__init__() got an unexpected keyword argument 'labels'
This occurs when using the GraphNetDataModule, although I believe the error lies within the ParquetDataset class itself: it does not accept labels as one of its arguments, even though the Dataset class has it in its __init__ ("labels: Dictionary of labels to be added to the dataset").
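The failure mode can be reproduced without graphnet at all. The sketch below uses hypothetical, heavily simplified stand-ins (`Dataset`, `ParquetDatasetBroken`, and the helper `try_build` are illustrations, not graphnet's actual code, and all other constructor arguments are omitted). The point is only that a subclass whose `__init__` does not list `labels` rejects it as a keyword argument, even though the base class accepts it.

```python
from typing import Any, Dict, Optional


class Dataset:
    """Hypothetical stand-in for a base dataset class that accepts labels."""

    def __init__(self, path: str, labels: Optional[Dict[str, Any]] = None) -> None:
        self.path = path
        # Extra labels to attach to each event; empty dict if none are given.
        self.labels = labels if labels is not None else {}


class ParquetDatasetBroken(Dataset):
    """Hypothetical subclass whose __init__ forgets the `labels` argument."""

    def __init__(self, path: str) -> None:
        super().__init__(path=path)


def try_build(**kwargs: Any) -> str:
    """Try to construct the broken subclass and return the error text, if any."""
    try:
        ParquetDatasetBroken(**kwargs)
    except TypeError as err:
        return str(err)
    return "ok"


# Passing labels fails with a TypeError, mirroring the error reported above.
message = try_build(path="events.parquet", labels={"direction": "placeholder"})
print(message)
```

The exact wording of the message depends on the Python version (newer versions prefix the qualified method name), but it always contains "unexpected keyword argument 'labels'".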
@OscarBarreraGithub, thanks for sharing this.
It looks like it's just a simple oversight in the argument list of ParquetDataset: it's missing the entry labels: Optional[Dict[str, Any]] = None, which should also be passed on in its call to super().
Fixing this would require just two lines of code:
- Add the argument to ParquetDataset: labels: Optional[Dict[str, Any]] = None
- Propagate the argument in the call to super()
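A minimal sketch of the two-line fix, again using hypothetical simplified stand-ins rather than graphnet's real classes (the real ParquetDataset takes many more arguments, all omitted here, and a real label would be something like a `Direction(...)` object rather than the placeholder string used below). The subclass accepts `labels` with the same default as the base class and forwards it to `super().__init__`, so the base class handles it exactly as it would for any other dataset.

```python
from typing import Any, Dict, Optional


class Dataset:
    """Hypothetical stand-in for a base dataset class that accepts labels."""

    def __init__(self, path: str, labels: Optional[Dict[str, Any]] = None) -> None:
        self.path = path
        # Extra labels to attach to each event; empty dict if none are given.
        self.labels = labels if labels is not None else {}


class ParquetDataset(Dataset):
    """Hypothetical simplified subclass with the fix applied."""

    def __init__(
        self,
        path: str,
        labels: Optional[Dict[str, Any]] = None,  # 1) accept the argument
    ) -> None:
        # 2) propagate it to the base class so it is handled in one place
        super().__init__(path=path, labels=labels)


ds = ParquetDataset("events.parquet", labels={"direction": "placeholder"})
print(ds.labels)
```

Keeping `None` as the default means existing callers that never pass `labels` are unaffected.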
closed by #730