Support Cycling a CombinedStreamingDataset
🚀 Feature
The ability to provide `length` as an input argument to CombinedStreamingDataset so that the epoch length is decoupled from the number of samples in the dataset, the same as ParallelStreamingDataset.
Motivation
I want to create a CombinedStreamingDataset that is a weighted combination of StreamingDatasets, but be able to specify the number of training steps arbitrarily and have the CombinedStreamingDataset cycle to fill them. As discussed with @tchaton.
Related to #524
Alternatives
Not sure if this would work, but conceptually one workaround might be to wrap the CombinedStreamingDataset in a ParallelStreamingDataset, e.g.
ds1 = StreamingDataset(...)
ds2 = StreamingDataset(...)
cds = CombinedStreamingDataset([ds1, ds2], weights)
pds = ParallelStreamingDataset([cds], length=100)
Implemented a prototype CyclingStreamingDataset which wraps CombinedStreamingDataset
Key Features
The CyclingStreamingDataset is a lightweight wrapper that provides four features for training pipelines:
- Fixed-Length Iteration (Cycling)
  - Iterate for a precise number of samples, regardless of the underlying dataset's size.
  - If the underlying dataset is exhausted before the target length is reached, it seamlessly "cycles" back to the beginning to continue providing data.
- Correct Shuffling on Each Cycle
  - Crucially, it ensures that each time the dataset is cycled, the underlying `litdata.StreamingDataset` is re-shuffled.
  - This is achieved by properly managing and incrementing an `epoch` counter, which seeds the shuffling algorithm in litdata, preventing the model from seeing the same data order in every pass.
- State Management & Resumability
  - Full support for checkpointing and resuming. The dataset implements `state_dict()` and `load_state_dict()`.
  - If training is interrupted, you can save the state and restore it later. The dataset will resume from the exact sample where it left off, with no data loss or duplication (see the sketch after this list).
- Distributed Training Awareness
  - Works correctly with LitData's `StreamingDataLoader` when `num_workers > 0`.
  - It ensures that the total workload is split correctly among all workers, with each worker receiving a unique, non-overlapping shard of the data. This prevents wasted computation and data duplication in distributed settings.
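For the resumability feature specifically, here is a minimal save/restore sketch. It assumes the CyclingStreamingDataset prototype defined further down in this issue and two already-optimized dataset directories; the paths, the `length=96`/`batch_size=8` values, and the interruption point are placeholders that mirror Test 1 in the validation script.
import litdata as ld
# Assumes the CyclingStreamingDataset prototype below; paths are placeholders.
ds1 = ld.StreamingDataset(input_dir="./data/dataset_1_optimized", shuffle=True)
ds2 = ld.StreamingDataset(input_dir="./data/dataset_2_optimized", shuffle=True)
train_ds = CyclingStreamingDataset([ds1, ds2], weights=(0.8, 0.2), length=96)
loader = ld.StreamingDataLoader(train_ds, batch_size=8, num_workers=0)
# ... training is interrupted after `seen` samples ...
seen = 48
train_ds._samples_yielded_previously = seen  # the prototype tracks this counter manually
state = train_ds.state_dict(num_workers=0, batch_size=8)
# Later: rebuild the pipeline, load the state, and iteration resumes at sample `seen`.
resumed_ds = CyclingStreamingDataset([ds1, ds2], weights=(0.8, 0.2), length=96)
resumed_ds.load_state_dict(state)
resumed_loader = ld.StreamingDataLoader(resumed_ds, batch_size=8, num_workers=0)
for batch in resumed_loader:
    pass  # yields the remaining 48 samples, with no duplicates of the first 48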
Usage Example
Here is a simple example demonstrating how to build a complete data pipeline using CyclingStreamingDataset with litdata.StreamingDataLoader.
import torch
from torch.utils.data import IterableDataset, get_worker_info
from typing import Dict, Any
import litdata
from litdata import train_test_split, StreamingDataLoader
# Assume 'CyclingStreamingDataset' class definition is available
# Assume 'setup_and_optimize_data' function is available to create data directories
# --- 1. Setup and Create Base Datasets ---
# This step creates the optimized data chunks that litdata reads from.
ds1_dir, ds2_dir = setup_and_optimize_data()
# Create two separate streaming datasets
ds1 = litdata.StreamingDataset(input_dir=ds1_dir, shuffle=True)
ds2 = litdata.StreamingDataset(input_dir=ds2_dir, shuffle=True)
# --- 2. Create Train/Validation Splits ---
# It's best practice to split each dataset individually before combining.
ds1_train, ds1_val = train_test_split(ds1, splits=[0.9, 0.1])
ds2_train, ds2_val = train_test_split(ds2, splits=[0.9, 0.1])
# --- 3. Combine Datasets with a Weighted Ratio using the CyclingStreamingDataset ---
# Combine the training portions of the datasets.
# Here, we sample 80% of our data from ds1_train and 20% from ds2_train.
weights = (0.8, 0.2)
# Define the total number of samples for one training epoch.
# This decouples our training loop from the ~135 samples in the underlying set.
total_training_samples = 10000
batch_size = 32
train_cycle_ds = CyclingStreamingDataset([ds1_train, ds2_train], weights=weights, length=total_training_samples)
# --- 4. Use in a StreamingDataLoader ---
# The resulting dataset works seamlessly with LitData's StreamingDataLoader.
# It is safe to use with multiple workers.
# Note: `shuffle` and `drop_last` are handled by the dataset itself, not the loader.
dataloader = StreamingDataLoader(
train_cycle_ds,
batch_size=batch_size,
num_workers=4
)
# Now you can iterate for a predictable number of steps
# The length is automatically calculated by the StreamingDataLoader
print(f"Starting training for {len(dataloader)} steps.")
for batch in dataloader:
# Your training logic here...
pass
print("Training epoch complete.")
Prototype implementation and validation script:
from torch.utils.data import get_worker_info
from typing import Dict, Any, Optional, List, Sequence, Iterator
import litdata as ld
import os
import shutil
import csv
# ======================================================================================
# 0. SETUP: CREATE AND OPTIMIZE DATA FOR LITDATA
# ======================================================================================
# `litdata.StreamingDataset` reads from a directory of optimized data chunks.
# First, we create some raw data files (CSVs in this case).
# Second, we use `litdata.optimize` to process these raw files into the required format.
# --- Define processing function for optimization ---
def process_csv(data_path):
"""Process CSV files and yield samples as dictionaries."""
with open(data_path) as f:
# Skip header
next(f)
for row in f:
# Yield samples as dictionaries
yield {"data": int(row.strip())}
def setup_and_optimize_data():
"""Creates raw CSV data and optimizes it for litdata."""
print("--- Step 0: Setting up and optimizing data for litdata ---")
base_dir = "./data"
# Clean up previous runs
if os.path.exists(base_dir):
shutil.rmtree(base_dir)
# Define source and optimized output directories
ds1_source_dir = os.path.join(base_dir, "dataset_1_source")
ds2_source_dir = os.path.join(base_dir, "dataset_2_source")
ds1_optimized_dir = os.path.join(base_dir, "dataset_1_optimized")
ds2_optimized_dir = os.path.join(base_dir, "dataset_2_optimized")
os.makedirs(ds1_source_dir, exist_ok=True)
os.makedirs(ds2_source_dir, exist_ok=True)
# --- Create dummy data files ---
ds1_data_path = os.path.join(ds1_source_dir, "data.csv")
with open(ds1_data_path, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["value"])
for i in range(100): # 100 samples
writer.writerow([i])
ds2_data_path = os.path.join(ds2_source_dir, "data.csv")
with open(ds2_data_path, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["value"])
for i in range(1000, 1050): # 50 samples
writer.writerow([i])
print("Created raw CSV files.")
# --- Run optimization ---
# This reads the CSVs and writes them into efficient binary chunks.
ld.optimize(
fn=process_csv,
inputs=[ds1_data_path],
output_dir=ds1_optimized_dir,
chunk_bytes="10MB", # Required parameter for litdata.optimize
num_workers=1,
)
ld.optimize(
fn=process_csv,
inputs=[ds2_data_path],
output_dir=ds2_optimized_dir,
chunk_bytes="10MB", # Required parameter for litdata.optimize
num_workers=1,
)
print("Optimization complete. Data is ready for StreamingDataset.\n")
return ds1_optimized_dir, ds2_optimized_dir
class CyclingStreamingDataset(ld.CombinedStreamingDataset):
"""
A production-ready PyTorch IterableDataset wrapper designed to solve a common challenge
in large-scale training: decoupling the number of training steps from the size of the
underlying streaming dataset.
This is particularly useful when working with CombinedStreamingDataset, where multiple
data sources of different sizes are mixed, making the total length of the combined
dataset non-trivial. CyclingStreamingDataset allows you to specify a fixed number of
samples for an epoch, ensuring consistent and predictable training loop lengths.
Key Features:
1. Fixed-Length Iteration (Cycling): Iterate for a precise number of samples, regardless
of the underlying dataset's size. If the underlying dataset is exhausted before the
target length is reached, it seamlessly "cycles" back to the beginning to continue
providing data.
2. Correct Shuffling on Each Cycle: Crucially, it ensures that each time the dataset
is cycled, the underlying StreamingDataset is re-shuffled. This is achieved by properly
managing and incrementing an epoch counter, which seeds the shuffling algorithm in
litdata, preventing the model from seeing the same data order in every pass.
3. State Management & Resumability: Full support for checkpointing and resuming. The
dataset implements state_dict() and load_state_dict(). If training is interrupted,
you can save the state and restore it later. The dataset will resume from the exact
sample where it left off, with no data loss or duplication.
4. Distributed Training Awareness: Works correctly with LitData's StreamingDataLoader when
num_workers > 0. It ensures that the total workload is split correctly among all
workers, with each worker receiving a unique, non-overlapping shard of the data.
"""
def __init__(
self,
datasets: List[ld.StreamingDataset],
length: int,
seed: int = 42,
weights: Optional[Sequence[float]] = None,
batching_method: str = "stratified",
force_override_state_dict: bool = False,
):
"""
Initialize the CyclingStreamingDataset.
Args:
datasets: The list of StreamingDataset to use.
length: The fixed number of samples to yield per epoch.
seed: The random seed to initialize the sampler.
weights: The sampling ratio for the datasets.
batching_method: When set to "stratified" (default), batches will include
samples from all datasets. When "per_stream", batches will
consist of samples from a single dataset, which is selected randomly.
force_override_state_dict: Boolean flag for allowing local arguments to
override a loaded state dict.
"""
# Initialize the parent CombinedStreamingDataset with iterate_over_all=False
# to ensure we can control the cycling behavior
super().__init__(
datasets=datasets,
seed=seed,
weights=weights,
iterate_over_all=False, # We handle cycling ourselves
batching_method=batching_method,
force_override_state_dict=force_override_state_dict,
)
self._target_length = length
self._current_epoch = 0
self._samples_yielded_previously = 0
def get_len(self, num_workers: int, batch_size: int) -> int:
"""Return the fixed length for this dataset."""
self.num_workers = num_workers
self.batch_size = batch_size
return self._target_length
def __len__(self) -> int:
"""Return the fixed length for this dataset."""
return self._target_length
def set_epoch(self, current_epoch: int) -> None:
"""Set the current epoch and increment our internal epoch counter for cycling."""
self._current_epoch = current_epoch
# Call parent's set_epoch to ensure proper shuffling
super().set_epoch(current_epoch)
def state_dict(
self,
num_workers: int,
batch_size: int,
num_samples_yielded: Optional[List[int]] = None,
) -> Dict[str, Any]:
"""
Capture the state of the dataset for resumability.
Returns:
A dictionary containing the cycling state and the underlying dataset state.
"""
# Get the underlying dataset state
underlying_state = super().state_dict(
num_workers, batch_size, num_samples_yielded
)
# Add our cycling-specific state
cycling_state = {
"samples_yielded": self._samples_yielded_previously,
"current_epoch": self._current_epoch,
"target_length": self._target_length,
}
return {
"cycling_state": cycling_state,
"underlying_state": underlying_state,
}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
Restore the state of the dataset for resumability.
Args:
state_dict: The state dictionary to restore from.
"""
if not state_dict:
return
# Load cycling-specific state
if "cycling_state" in state_dict:
cycling_state = state_dict["cycling_state"]
self._samples_yielded_previously = cycling_state.get("samples_yielded", 0)
self._current_epoch = cycling_state.get("current_epoch", 0)
# Note: target_length is not restored as it's fixed
# Load underlying dataset state
if "underlying_state" in state_dict:
# The parent CombinedStreamingDataset expects the state in a specific format
# with a "dataset" key containing the individual dataset states
combined_state = {"dataset": state_dict["underlying_state"]}
super().load_state_dict(combined_state)
def __iter__(self) -> Iterator[Any]:
"""
Create an iterator that yields exactly the target number of samples,
cycling through the underlying dataset with proper re-shuffling.
"""
worker_info = get_worker_info()
num_workers = worker_info.num_workers if worker_info else 1
worker_id = worker_info.id if worker_info else 0
# Calculate how many samples this worker is responsible for
worker_total_length = self._target_length // num_workers
if worker_id < self._target_length % num_workers:
worker_total_length += 1
# Calculate how many samples this worker needs to skip to resume correctly
num_to_skip = self._samples_yielded_previously // num_workers
if worker_id < self._samples_yielded_previously % num_workers:
num_to_skip += 1
# Calculate how many samples this worker needs to yield in this run
num_to_yield_this_run = worker_total_length - num_to_skip
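# Worked example: in Test 2 below (length=96, num_workers=2) each worker owns
# 96 // 2 = 48 samples and skips nothing; in Test 1 (single worker, resumed after
# 48 samples) the lone worker owns all 96, skips 48, and yields the remaining 48.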
if num_to_yield_this_run <= 0:
return # This worker has already yielded all its samples
# Get the underlying iterator
underlying_iter = super().__iter__()
# Skip the required number of samples
for _ in range(num_to_skip):
try:
next(underlying_iter)
except StopIteration:
# If the underlying dataset is exhausted while skipping,
# we need to cycle it. This should not happen in normal operation
# but we handle it gracefully.
self._current_epoch += 1
self.set_epoch(self._current_epoch)
underlying_iter = super().__iter__()
# Yield the required number of samples
samples_yielded_this_run = 0
while samples_yielded_this_run < num_to_yield_this_run:
try:
yield next(underlying_iter)
samples_yielded_this_run += 1
except StopIteration:
# The underlying dataset is exhausted. Cycle it by incrementing epoch
# and creating a new iterator, which will re-shuffle the data.
self._current_epoch += 1
self.set_epoch(self._current_epoch)
underlying_iter = super().__iter__()
def main():
"""Main function to run the litdata exploration."""
# Run the setup
ds1_dir, ds2_dir = setup_and_optimize_data()
# ======================================================================================
# 1. CREATE LITDATA STREAMING DATASETS (Requirement: Shuffling)
# ======================================================================================
print("--- Step 1: Creating litdata.StreamingDataset instances ---")
# Requirement: shuffle=True for both datasets.
# litdata's shuffling is highly optimized for large-scale data.
ds1 = ld.StreamingDataset(input_dir=ds1_dir, shuffle=True)
ds2 = ld.StreamingDataset(input_dir=ds2_dir, shuffle=True)
print(f"Created ds1 with {len(ds1)} samples.")
print(f"Created ds2 with {len(ds2)} samples.")
print("Shuffling is enabled for both datasets.\n")
# ======================================================================================
# 2. COMBINE DATASETS (Requirement: Weighted Combination)
# ======================================================================================
print("--- Step 2: Combining Datasets with Weights ---")
weights = (0.8, 0.2)
# First, split the individual datasets
print("Splitting individual datasets...")
ds1_train, ds1_val = ld.train_test_split(ds1, splits=[0.9, 0.1])
ds2_train, ds2_val = ld.train_test_split(ds2, splits=[0.9, 0.1])
# Then combine the train and validation datasets separately
train_ds = ld.CombinedStreamingDataset(
[ds1_train, ds2_train], weights=weights, iterate_over_all=False
)
val_ds = ld.CombinedStreamingDataset(
[ds1_val, ds2_val], weights=weights, iterate_over_all=False
)
print(f"Combined datasets with weights: {weights}")
print("Verifying the weighted combination using DataLoader batches:")
# Use DataLoader to sample in batches and compute ratios
batch_size = 32
train_dataloader = ld.StreamingDataLoader(
train_ds, batch_size=batch_size, num_workers=0
)
batch_ratios = []
total_samples = 0
dataset_1_total = 0
dataset_2_total = 0
# Sample a fixed number of batches for analysis
num_batches_to_analyze = 10
for i, batch in enumerate(train_dataloader):
if i >= num_batches_to_analyze:
break
# Count samples from each dataset in this batch
batch_dataset_1_count = 0
batch_dataset_2_count = 0
for data_val in batch["data"]:
if data_val < 100:
batch_dataset_1_count += 1
dataset_1_total += 1
else:
batch_dataset_2_count += 1
dataset_2_total += 1
total_samples += 1
# Compute ratio for this batch
batch_total = batch_dataset_1_count + batch_dataset_2_count
if batch_total > 0:
batch_ratio = batch_dataset_1_count / batch_total
batch_ratios.append(batch_ratio)
print(
f" Batch {i + 1}: dataset_1 ratio = {batch_ratio:.3f} ({batch_dataset_1_count}/{batch_total})"
)
# Compute overall statistics
overall_ratio = dataset_1_total / total_samples if total_samples > 0 else 0
avg_batch_ratio = sum(batch_ratios) / len(batch_ratios) if batch_ratios else 0
print(f"\nOverall statistics across {len(batch_ratios)} batches:")
print(f" Average batch ratio (dataset_1): {avg_batch_ratio:.3f}")
print(f" Overall ratio (dataset_1): {overall_ratio:.3f}")
print(f" Expected ratio (dataset_1): {weights[0]:.3f}")
print(f" Total samples analyzed: {total_samples}")
print(f" Samples from dataset_1: {dataset_1_total}")
print(f" Samples from dataset_2: {dataset_2_total}")
# Verify no validation data leakage and compute dataset coverage
print("\n--- Validation Data Leakage Check ---")
# Collect all unique values seen in training data
train_values = set()
for i, batch in enumerate(train_dataloader):
if i >= num_batches_to_analyze:
break
for data_val in batch["data"]:
train_values.add(data_val.item())
# Check validation datasets for any overlap
val_dataloader = ld.StreamingDataLoader(
val_ds, batch_size=batch_size, num_workers=0
)
val_values = set()
for batch in val_dataloader:
for data_val in batch["data"]:
val_values.add(data_val.item())
# Check for overlap
overlap = train_values.intersection(val_values)
if overlap:
print(
f" ā ļø WARNING: Found {len(overlap)} overlapping values between train and validation sets!"
)
print(f" Overlapping values: {sorted(list(overlap))}")
else:
print(
" ā
SUCCESS: No data leakage detected between train and validation sets"
)
# Compute dataset coverage percentages
print("\n--- Dataset Coverage Analysis ---")
# Original dataset sizes
ds1_original_size = len(ds1) # 100
ds2_original_size = len(ds2) # 50
# Count unique values from each original dataset in training data
ds1_train_values = [v for v in train_values if v < 100]
ds2_train_values = [v for v in train_values if v >= 1000]
ds1_coverage = len(ds1_train_values) / ds1_original_size * 100
ds2_coverage = len(ds2_train_values) / ds2_original_size * 100
print(
f" Dataset 1 (0-99): {len(ds1_train_values)}/{ds1_original_size} unique values seen ({ds1_coverage:.1f}% coverage)"
)
print(
f" Dataset 2 (1000-1049): {len(ds2_train_values)}/{ds2_original_size} unique values seen ({ds2_coverage:.1f}% coverage)"
)
# Note: CombinedStreamingDataset with iterate_over_all=False returns None for length
# because the length is variable due to weighted random sampling
# ======================================================================================
# 3. CYCLING DATASET (Requirement: Cycling for Fixed Length)
# ======================================================================================
print("--- Step 3: Wrapping the Training Dataset for Cycling ---")
train_steps = 200
train_cycle_ds = CyclingStreamingDataset(
datasets=[ds1_train, ds2_train], # Use individual datasets, not combined ones
length=train_steps * batch_size,
weights=weights,
batching_method="stratified",
)
print(
f"Wrapped the training dataset to cycle for a fixed length of {len(train_cycle_ds)} samples."
)
print(
f"The underlying train set has variable length due to weighted sampling, but we will draw {train_steps} batches of size {batch_size}.\n"
)
# ======================================================================================
# 4. DATALOADER AND DEMONSTRATION
# ======================================================================================
print("--- Step 4: Creating DataLoader and Demonstrating the Full Pipeline ---")
dataloader = ld.StreamingDataLoader(
train_cycle_ds, batch_size=batch_size, num_workers=0
)
source_counts = {"dataset_1": 0, "dataset_2": 0}
total_samples_processed = 0
total_steps = 0
print(f"Starting training loop with batch_size={batch_size}...")
for i, batch in enumerate(dataloader):
# Determine sources based on data values
batch_sources = []
for data_val in batch["data"]:
source = "dataset_1" if data_val < 100 else "dataset_2"
source_counts[source] += 1
batch_sources.append(source)
total_samples_processed += len(batch["data"])
total_steps += 1
if i < 3:
print(f" Batch {i + 1}:")
print(f" Data: {batch['data']}")
print(f" Sources: {batch_sources}")
print("\n--- Verification ---")
print(f"Total samples processed: {total_samples_processed}")
print(f"Expected samples: {len(train_cycle_ds)}")
print(f"Total steps: {total_steps}")
print(f"Expected steps: {train_steps}")
d1_count = source_counts["dataset_1"]
d2_count = source_counts["dataset_2"]
total_count = d1_count + d2_count
d1_ratio = d1_count / total_count if total_count > 0 else 0
d2_ratio = d2_count / total_count if total_count > 0 else 0
print("\nSource distribution verification:")
print(f" - Samples from dataset_1: {d1_count} (~{d1_ratio:.2%})")
print(f" - Samples from dataset_2: {d2_count} (~{d2_ratio:.2%})")
print(f" - Expected ratio: {weights[0]}/{weights[1]}")
# Dataset coverage analysis for the full training simulation
print("\n--- Final Training Dataset Coverage Analysis ---")
# Collect all unique values seen in the full training simulation
final_train_values = set()
for i, batch in enumerate(dataloader):
for data_val in batch["data"]:
final_train_values.add(data_val.item())
# Count unique values from each original dataset in final training data
ds1_final_values = [v for v in final_train_values if v < 100]
ds2_final_values = [v for v in final_train_values if v >= 1000]
ds1_final_coverage = len(ds1_final_values) / ds1_original_size * 100
ds2_final_coverage = len(ds2_final_values) / ds2_original_size * 100
print(
f" Dataset 1 (0-99): {len(ds1_final_values)}/{ds1_original_size} unique values seen ({ds1_final_coverage:.1f}% coverage)"
)
print(
f" Dataset 2 (1000-1049): {len(ds2_final_values)}/{ds2_original_size} unique values seen ({ds2_final_coverage:.1f}% coverage)"
)
# Compare with initial coverage
print("\n--- Coverage Comparison (Initial vs Final) ---")
print(
f" Dataset 1: {ds1_coverage:.1f}% ā {ds1_final_coverage:.1f}% (change: {ds1_final_coverage - ds1_coverage:+.1f}%)"
)
print(
f" Dataset 2: {ds2_coverage:.1f}% ā {ds2_final_coverage:.1f}% (change: {ds2_final_coverage - ds2_coverage:+.1f}%)"
)
if abs(d1_ratio - weights[0]) < 0.1:
print(
"\nā
Success: The prototype demonstrates all four key requirements with real litdata.StreamingDataset instances."
)
else:
print(
"\nā ļø Note: The observed ratio may differ from the weights due to stochastic sampling."
)
# ======================================================================================
# 5. ADVANCED FEATURE VALIDATION
# ======================================================================================
print("--- Step 5: Advanced Feature Validation ---")
batch_size = 8
total_length = (
96 # Must be divisible by batch_size and num_workers for easy testing
)
# --- Test 1: State Management and Resumability ---
print(
"\n--- Test 1: Validating State Management and Resumability (num_workers=1) ---"
)
# Run 1: Iterate through the first half of the data
print(" Run 1: Processing first half of the data...")
resumable_ds = CyclingStreamingDataset(
datasets=[ds1_train, ds2_train], # Use individual datasets, not combined ones
length=total_length,
weights=weights,
batching_method="stratified",
)
loader1 = ld.StreamingDataLoader(resumable_ds, batch_size=batch_size, num_workers=0)
first_half_samples = []
samples_processed_run1 = 0
for i, batch in enumerate(loader1):
if i >= (total_length // batch_size) // 2:
break
first_half_samples.extend(d.item() for d in batch["data"])
samples_processed_run1 += len(batch["data"])
print(
f" Processed {samples_processed_run1} samples. Simulating training interruption."
)
# Capture state
resumable_ds._samples_yielded_previously = samples_processed_run1
state = resumable_ds.state_dict(num_workers=0, batch_size=batch_size)
print(f" Saved state: {state}")
# Run 2: Create a new instance, load state, and iterate through the second half
print("\n Run 2: Resuming from saved state...")
resumed_ds = CyclingStreamingDataset(
datasets=[ds1_train, ds2_train], # Use individual datasets, not combined ones
length=total_length,
weights=weights,
batching_method="stratified",
)
resumed_ds.load_state_dict(state)
loader2 = ld.StreamingDataLoader(resumed_ds, batch_size=batch_size, num_workers=0)
second_half_samples = []
for batch in loader2:
second_half_samples.extend(d.item() for d in batch["data"])
print(f" Processed {len(second_half_samples)} samples in the second run.")
# Validation for Test 1
print("\n Validation Results for Resumability:")
total_samples_from_both_runs = len(first_half_samples) + len(second_half_samples)
total_unique_samples = len(set(first_half_samples + second_half_samples))
overlap = set(first_half_samples).intersection(set(second_half_samples))
print(
f" Total samples collected across both runs: {total_samples_from_both_runs}"
)
print(f" Total unique samples collected: {total_unique_samples}")
print(f" Expected total samples: {total_length}")
assert total_samples_from_both_runs == total_length, (
"Test Failed: Incorrect number of total samples!"
)
assert total_unique_samples == total_length, (
"Test Failed: Samples were duplicated after resuming!"
)
print(
" ā
SUCCESS: Resumability test passed. The dataset correctly resumed without duplicates or data loss."
)
# --- Test 2: Distributed Training Awareness ---
print("\n--- Test 2: Validating Distributed Training Awareness (num_workers=2) ---")
num_workers = 2
distributed_ds = CyclingStreamingDataset(
datasets=[ds1_train, ds2_train], # Use individual datasets, not combined ones
length=total_length,
weights=weights,
batching_method="stratified",
)
loader_dist = ld.StreamingDataLoader(
distributed_ds, batch_size=batch_size, num_workers=num_workers
)
print(
f" Fetching all {total_length} samples using a DataLoader with {num_workers} workers..."
)
all_samples_distributed = []
for batch in loader_dist:
all_samples_distributed.extend(d.item() for d in batch["data"])
print(f" Collected {len(all_samples_distributed)} total samples.")
# Validation for Test 2
print("\n Validation Results for Distributed Awareness:")
total_unique_samples_dist = len(set(all_samples_distributed))
print(f" Total unique samples collected: {total_unique_samples_dist}")
print(f" Expected total unique samples: {total_length}")
assert len(all_samples_distributed) == total_length, (
"Test Failed: Incorrect total number of samples yielded!"
)
assert total_unique_samples_dist == total_length, (
"Test Failed: Duplicates found! Workers did not receive unique data shards."
)
print(
" ā
SUCCESS: Distributed awareness test passed. Each worker processed a unique data shard."
)
# Clean up the created data directories
print("\nCleaning up data directories...")
if os.path.exists("./data"):
shutil.rmtree("./data")
print("Cleanup complete.")
if __name__ == "__main__":
main()
Which produces the following output:
--- Step 0: Setting up and optimizing data for litdata ---
Created raw CSV files.
Create an account on https://lightning.ai/ to optimize your data faster using multiple nodes and large machines.
Setting multiprocessing start_method to spawn.
Storing the files under /home/user/data/dataset_1_optimized
Setup started with fast_dev_run=False.
Setup finished in 0.001 seconds. Found 1 items to process.
Starting 1 workers with 1 items. The progress bar is only updated when a worker finishes.
Workers are ready ! Starting data processing...
Rank 0 inferred the following `['int']` data format.
Worker 0 is terminating.
Worker 0 is done.
Progress: 100%|██████████| 1/1 [00:02<00:00, 2.13s/it]
Workers are finished.
Finished data processing!
Create an account on https://lightning.ai/ to optimize your data faster using multiple nodes and large machines.
Setting multiprocessing start_method to spawn.
Storing the files under /home/user/data/dataset_2_optimized
Setup started with fast_dev_run=False.
Setup finished in 0.0 seconds. Found 1 items to process.
Starting 1 workers with 1 items. The progress bar is only updated when a worker finishes.
Workers are ready ! Starting data processing...
Rank 0 inferred the following `['int']` data format.
Worker 0 is terminating.
Worker 0 is done.
Progress: 100%|██████████| 1/1 [00:02<00:00, 2.16s/it]
Workers are finished.
Finished data processing!
Optimization complete. Data is ready for StreamingDataset.
--- Step 1: Creating litdata.StreamingDataset instances ---
Created ds1 with 100 samples.
Created ds2 with 50 samples.
Shuffling is enabled for both datasets.
--- Step 2: Combining Datasets with Weights ---
Splitting individual datasets...
Combined datasets with weights: (0.8, 0.2)
Verifying the weighted combination using DataLoader batches:
Batch 1: dataset_1 ratio = 0.812 (26/32)
Batch 2: dataset_1 ratio = 0.844 (27/32)
Batch 3: dataset_1 ratio = 0.812 (26/32)
Batch 4: dataset_1 ratio = 0.688 (11/16)
Overall statistics across 4 batches:
Average batch ratio (dataset_1): 0.789
Overall ratio (dataset_1): 0.804
Expected ratio (dataset_1): 0.800
Total samples analyzed: 112
Samples from dataset_1: 90
Samples from dataset_2: 22
--- Validation Data Leakage Check ---
✅ SUCCESS: No data leakage detected between train and validation sets
--- Dataset Coverage Analysis ---
Dataset 1 (0-99): 90/100 unique values seen (90.0% coverage)
Dataset 2 (1000-1049): 22/50 unique values seen (44.0% coverage)
--- Step 3: Wrapping the Training Dataset for Cycling ---
Wrapped the training dataset to cycle for a fixed length of 6400 samples.
The underlying train set has variable length due to weighted sampling, but we will draw 200 batches of size 32.
--- Step 4: Creating DataLoader and Demonstrating the Full Pipeline ---
Starting training loop with batch_size=32...
Batch 1:
Data: tensor([ 30, 32, 56, 38, 6, 50, 1033, 25, 21, 57, 0, 60,
45, 89, 3, 53, 63, 85, 1018, 20, 1016, 24, 71, 11,
1031, 77, 42, 35, 1021, 80, 1013, 74])
Sources: ['dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_2', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_2', 'dataset_1', 'dataset_2', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_2', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_2', 'dataset_1', 'dataset_2', 'dataset_1']
Batch 2:
Data: tensor([ 5, 1004, 39, 55, 1005, 88, 1006, 47, 27, 18, 9, 69,
40, 48, 26, 70, 37, 68, 87, 51, 43, 1003, 64, 62,
83, 61, 12, 66, 1010, 34, 4, 29])
Sources: ['dataset_1', 'dataset_2', 'dataset_1', 'dataset_1', 'dataset_2', 'dataset_1', 'dataset_2', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_2', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_2', 'dataset_1', 'dataset_1', 'dataset_1']
Batch 3:
Data: tensor([1000, 22, 19, 13, 10, 14, 16, 1025, 1039, 79, 76, 72,
1028, 78, 7, 67, 33, 17, 15, 1007, 52, 1, 1030, 82,
54, 31, 73, 49, 46, 2, 59, 28])
Sources: ['dataset_2', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_2', 'dataset_2', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_2', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_2', 'dataset_1', 'dataset_1', 'dataset_2', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1', 'dataset_1']
--- Verification ---
Total samples processed: 6400
Expected samples: 6400
Total steps: 200
Expected steps: 200
Source distribution verification:
- Samples from dataset_1: 5145 (~80.39%)
- Samples from dataset_2: 1255 (~19.61%)
- Expected ratio: 0.8/0.2
--- Final Training Dataset Coverage Analysis ---
Dataset 1 (0-99): 90/100 unique values seen (90.0% coverage)
Dataset 2 (1000-1049): 45/50 unique values seen (90.0% coverage)
--- Coverage Comparison (Initial vs Final) ---
Dataset 1: 90.0% → 90.0% (change: +0.0%)
Dataset 2: 44.0% → 90.0% (change: +46.0%)
✅ Success: The prototype demonstrates all four key requirements with real litdata.StreamingDataset instances.
--- Step 5: Advanced Feature Validation ---
--- Test 1: Validating State Management and Resumability (num_workers=1) ---
Run 1: Processing first half of the data...
Processed 48 samples. Simulating training interruption.
Saved state: {'cycling_state': {'samples_yielded': 48, 'current_epoch': 1, 'target_length': 96}, 'underlying_state': {'0': {'num_samples_yielded': 46, 'num_workers': 1, 'batch_size': 8, 'current_epoch': 1, 'input_dir_path': '/home/user/data/dataset_1_optimized', 'input_dir_url': None, 'cache_dir_path': None, 'item_loader': None, 'drop_last': False, 'seed': 42, 'world_size': 1, 'shuffle': True, 'subsampled_files': ['chunk-0-0.bin'], 'region_of_interest': [(0, 90)]}, '1': {'num_samples_yielded': 10, 'num_workers': 1, 'batch_size': 8, 'current_epoch': 1, 'input_dir_path': '/home/user/data/dataset_2_optimized', 'input_dir_url': None, 'cache_dir_path': None, 'item_loader': None, 'drop_last': False, 'seed': 42, 'world_size': 1, 'shuffle': True, 'subsampled_files': ['chunk-0-0.bin'], 'region_of_interest': [(0, 45)]}}}
Run 2: Resuming from saved state...
Processed 48 samples in the second run.
Validation Results for Resumability:
Total samples collected across both runs: 96
Total unique samples collected: 96
Expected total samples: 96
✅ SUCCESS: Resumability test passed. The dataset correctly resumed without duplicates or data loss.
--- Test 2: Validating Distributed Training Awareness (num_workers=2) ---
Fetching all 96 samples using a DataLoader with 2 workers...
Collected 96 total samples.
Validation Results for Distributed Awareness:
Total unique samples collected: 96
Expected total unique samples: 96
✅ SUCCESS: Distributed awareness test passed. Each worker processed a unique data shard.
Cleaning up data directories...
Cleanup complete.
I notice that when one dataset is exhausted, you reset the whole CombinedStreamingDataset rather than resetting only the one that ran out. I'm wondering if this is the desired behavior?
Instead of adding a new class, I think we should either make cycling a direct feature of StreamingDataset as suggested in #524, ~~or make ParallelStreamingDataset compatible with CombinedStreamingDataset as suggested by @deependujha in #576. ParallelStreamingDataset already implements the "if one dataset is exhausted then reset that one only instead of the whole thing" logic.~~
~~The second option might be simpler to implement. I can try to give it a go when I have the time.~~
Actually no, that might not work.