Fix input type checking for indices in `classification_subset`
`classification_subset` expects a list of indices as input, but it currently also accepts a tensor of indices. However, if you then concatenate the resulting subset with another dataset, `FlatData.concat()` fails with the following error:
`TypeError: unsupported operand type(s) for +: 'Tensor' and 'list'`
We should either require users to pass a Python list or automatically convert a `LongTensor` to a Python list, so that such errors do not surface later on.
I don't think this is a bug. Are you asking for static type checking? We should add type annotations (I'm not sure whether they are already there), but I don't think we should enforce static type checking.
The fact that the subset operation accepts a `LongTensor` is not a bug on its own, but it can cause problems later if the user performs operations such as concatenation. I was thinking of changing the type of `indices` to `Union[List[int], LongTensor]` and converting `LongTensor`s to a list inside the subset function (simply via `.tolist()`) to avoid these issues.
We can change it to `Sequence[int]` if:
- the new method works with all `Sequence[int]`s, not just tensors. It can still have a special case in the function body, but not in the signature.
- there are no performance regressions.
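A minimal sketch of how the `Sequence[int]` approach could look, assuming the conversion happens once at the top of `classification_subset`. `normalize_indices` is a hypothetical helper name, not part of Avalanche's API; the sketch relies only on the fact that PyTorch tensors (and NumPy arrays) expose `.tolist()`. The tensor handling lives in the body, as requested above, and plain lists are passed through untouched so existing callers pay no extra copy.

```python
from typing import List, Sequence


def normalize_indices(indices: Sequence[int]) -> List[int]:
    """Return ``indices`` as a plain Python list of ints.

    Hypothetical helper for illustration; not an existing Avalanche function.
    """
    # Fast path: lists are returned as-is, so callers that already pass
    # a list see no extra copy (the "no performance regressions" condition).
    if isinstance(indices, list):
        return indices
    # Special case kept in the body, not in the signature: tensors and
    # NumPy arrays provide .tolist(), which yields Python ints.
    tolist = getattr(indices, "tolist", None)
    if callable(tolist):
        return tolist()
    # Any other Sequence[int] (tuple, range, ...) is copied into a list.
    return [int(i) for i in indices]


# Once both sides are lists, "+" (as used when concatenating) works.
left = normalize_indices(range(3))
right = normalize_indices([5, 6])
print(left + right)  # [0, 1, 2, 5, 6]
```

With PyTorch installed, `normalize_indices(torch.tensor([0, 1]))` takes the `.tolist()` branch and returns `[0, 1]`. Returning a list unchanged avoids a copy but means the subset shares the caller's list; copying with `list(indices)` would be safer if callers might mutate it afterwards, at a small cost.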