
Add support for multi-sample items in optimize and yielding from the __getitem__ of the StreamingDataset

Open tchaton opened this issue 1 year ago • 10 comments

🚀 Feature

Motivation

It would be great to be able to create a batch of sub-samples from a given sample. Right now, you can't do that.

However, a user could express this as follows:


def optimize(...):
    sample = ...  # build or load the sample as usual
    # Signal that this item should yield multiple sub-samples at stream time.
    return MultiSample(sample, num_samples=X)

Under the hood, we know this sample could be used to generate multiple random samples.
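
For illustration, MultiSample could be a simple marker type (hypothetical, not an existing litdata class):

from dataclasses import dataclass
from typing import Any

@dataclass
class MultiSample:
    # Wraps one optimized sample that should be expanded into
    # `num_samples` sub-samples at stream time.
    sample: Any
    num_samples: int

On the streaming side, the dataset could then receive the sub-sample id alongside the index: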


class MyStreamingDataset(StreamingDataset):

    def __getitem__(self, index, sample_id):
        sample = super().__getitem__(index)
        # do some transformation based on sample_id
        data = ...  # e.g. the sample_id-th random crop of `sample`
        return data

A use case would be object detection, where each image can be used to generate multiple sub-boxes and we might want to treat them as different training samples.


tchaton avatar Aug 08 '24 19:08 tchaton

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale[bot] avatar Apr 16 '25 05:04 stale[bot]

Let's keep this open; it's an interesting one.

bhimrazy avatar Jun 08 '25 16:06 bhimrazy

I think this issue will be interesting for new contributors once we have a clear breakdown of it. I’ll also brainstorm a bit more on it.

One simple idea that comes to mind: generate x variants of the same image using minor augmentations—such as flip, rotation, intensity change, crop, etc.—similar to how torchvision.transforms works. We could even leverage torchvision under the hood.

This could also be extended to support other use cases, like object detection, segmentation, etc., as mentioned in the issue.
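
As a rough sketch of that idea, assuming torchvision (the augmentation list and the make_variants helper are illustrative, not an existing API):

import torchvision.transforms as T

# One minor augmentation per generated variant (illustrative choices).
AUGMENTATIONS = [
    T.RandomHorizontalFlip(p=1.0),
    T.RandomRotation(degrees=15),
    T.ColorJitter(brightness=0.3),
    T.RandomResizedCrop(size=224),
]

def make_variants(image):
    # Generate len(AUGMENTATIONS) variants from a single source image.
    return [aug(image) for aug in AUGMENTATIONS]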

Curious to hear your thoughts on this: @tchaton @deependujha

bhimrazy avatar Jun 08 '25 16:06 bhimrazy

really nice issue.

@bhimrazy my understanding of the issue is: the optimized dataset will contain only one sample, but while streaming, the same sample will be yielded multiple times (along with a sample id).

Now, it's up to the user to decide whether they want to use it with torchvision, add noise to an audio signal, or something else.

Also, I think it'll be more useful to either override __next__ or add a transform method in the StreamingDataset (a feature often requested) that users can override.

I expect it to be doable for new contributors, but we have to consider the length of the streaming dataset.

edit:

my bad, __getitem__ makes more sense, as __next__ also internally calls __getitem__.

deependujha avatar Jun 09 '25 23:06 deependujha

Ah, I see now—this actually happens at stream time, not during optimization as I initially thought. Thanks for clarifying that, @deependujha.

bhimrazy avatar Jun 10 '25 05:06 bhimrazy

Also, I think shuffling in this case will be interesting.

My approach for this will be:

  • Add an additional property in the index.json file called sample_count, which will specify how many samples you want from each item, rather than using something like MultiSample.
  • Optimize dataset as usual.
  • While streaming, inflate the dataset length to original_len * sample_count. The index being read can then be computed by:
original_index = inflated_index // sample_count
sample_index = inflated_index % sample_count
  • We also have to think of some clever way to get the chunk number (if not immediately available).

which will then be passed into the transform fn (#618).

And then users can apply their transformations there, as in the sketch below.
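
A minimal sketch of this approach, assuming a sample_count parameter and a user-overridable transform hook (neither exists in litdata today):

from litdata import StreamingDataset

class MultiSampleStreamingDataset(StreamingDataset):
    def __init__(self, *args, sample_count=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.sample_count = sample_count

    def __len__(self):
        # Inflate the length so each item appears `sample_count` times.
        return super().__len__() * self.sample_count

    def __getitem__(self, inflated_index):
        original_index = inflated_index // self.sample_count
        sample_index = inflated_index % self.sample_count
        sample = super().__getitem__(original_index)
        # Hand the sub-sample id to the user-provided transform hook.
        return self.transform(sample, sample_index)

    def transform(self, sample, sample_index):
        # Users override this to derive the sample_index-th sub-sample.
        raise NotImplementedError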

deependujha avatar Jun 11 '25 07:06 deependujha

Hey @deependujha, @tchaton, @bhimrazy! Is this still available to take up?

VijayVignesh1 avatar Oct 23 '25 19:10 VijayVignesh1

Sure @VijayVignesh1, go ahead! Excited to see your contribution 🚀

deependujha avatar Oct 24 '25 04:10 deependujha

Is there a reason why you suggested modifying index.json? The solution I had in mind is:

  1. Optimize the dataset normally.
  2. Add an is_multisample boolean as a class parameter which users can toggle.
  3. Accept a list of transforms via self.transform and inflate the dataset length to original_length * len(self.transform) (sketched below).
  4. Compute the original index as (with sample_count = len(self.transform)):
original_index = inflated_index // sample_count
sample_index = inflated_index % sample_count
  5. Use the specific transform function corresponding to sample_index. Since we are only working with the original index, we don't need to worry about the chunk number, right? @deependujha
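
A rough sketch of this variant, assuming a hypothetical transforms list parameter (sample_count here is just the number of transforms):

from litdata import StreamingDataset

class TransformListStreamingDataset(StreamingDataset):
    def __init__(self, *args, transforms=None, **kwargs):
        super().__init__(*args, **kwargs)
        # One transform per sub-sample; defaults to the identity.
        self.transforms = transforms or [lambda x: x]

    def __len__(self):
        return super().__len__() * len(self.transforms)

    def __getitem__(self, inflated_index):
        sample_count = len(self.transforms)
        original_index = inflated_index // sample_count
        sample_index = inflated_index % sample_count
        sample = super().__getitem__(original_index)
        # Apply the transform corresponding to this sub-sample.
        return self.transforms[sample_index](sample)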

VijayVignesh1 avatar Oct 24 '25 13:10 VijayVignesh1

Hi @VijayVignesh1,

Sure, your approach makes sense to me. My initial thought around modifying index.json came from how Thomas initially framed the idea (via the optimize function), but your plan feels cleaner and more straightforward.

That said, modifying index.json could still be handy in an extended scenario: for example, when users want to assign different sample_count values per item. In that case, having the count persisted in index.json might make it easier to reproduce or reload the dataset later.

For the general case though, your is_multisample flag + transform-based inflation is definitely the simpler path.

But I'd prefer a sample_count parameter (defaulting to 1) over a boolean flag. That way, we can easily handle both single- and multi-sample cases without extra conditionals.

deependujha avatar Oct 29 '25 07:10 deependujha