PyTorch Lightning Fabric + LitData + DDP hangs when finishing an epoch
🐛 Bug
I am training CLIP using PyTorch Lightning Fabric + LitData on a distributed setup (4 nodes, each with 4 GPUs). I noticed that when finishing the 1st epoch, the training dataloaders hang on some nodes.
The image below shows fabric.print() logging on the 4 nodes before finishing an epoch (I print every 25 steps). Only one rank finishes successfully and the rest hang; otherwise the message ++++ Epoch: 0 completed ++++ would appear 4 times, once on each node. I share the relevant parts of the code below; any help would be appreciated.
Additional context + Parts of Training code
- The data chunks are downloaded on each node before Fabric launches the subprocesses for all devices via popen, so the data loading is local, not streamed (a rough sketch of this staging step is shown after this list).
- I tested mosaicml-streaming on the same Fabric training code and it works without issues.
- I described in another issue the VertexAI cluster environment I use to configure Fabric.
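For context, this is roughly how I stage the chunks on each node before fabric.launch() (the bucket and local paths are placeholders, and the gsutil call stands in for whatever copy mechanism is used; treat it as an illustration, not my exact code):

import subprocess

def stage_chunks_locally(remote_dir: str, local_dir: str) -> None:
    """Copy the optimized LitData chunks from GCS to local disk, once per node."""
    subprocess.run(["gsutil", "-m", "rsync", "-r", remote_dir, local_dir], check=True)

# Hypothetical usage, run on every node before the per-device processes are spawned:
# stage_chunks_locally("gs://my-bucket/clip-chunks/train", "/data/train")
# stage_chunks_locally("gs://my-bucket/clip-chunks/test", "/data/test")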
StreamingDataset setup
I am using the following script to load the LitData dataset:
class ImageCaptionDataset(StreamingDataset):
"""
Main dataset to retrieve the image-caption optimized dataset
"""
@staticmethod
def get_zero_shot_one_hot(zero_shot_attributes: List[int]):
one_hot_encoded = torch.zeros(len(CLASSES), dtype=torch.float)
one_hot_encoded[zero_shot_attributes] = 1.0
return one_hot_encoded
def __getitem__(self, idx: int) -> Any:
_, image_bytes, text_ids, mask, zero_shot_attr = super().__getitem__(idx)
image = Image.open(io.BytesIO(image_bytes))
input_ids = torch.tensor(np.frombuffer(text_ids, dtype=np.int64))
attention_mask = torch.tensor(np.frombuffer(mask, dtype=np.int64))
zero_shot_attr = self.get_zero_shot_one_hot(literal_eval(zero_shot_attr))
return image, input_ids, attention_mask, zero_shot_attr
def collate_fn(batch, processor):
"""Arrange the batch into a dictionary of tensors for HF
"""
images = processor(images=[ex[0] for ex in batch], return_tensors="pt")
return {
"pixel_values": images["pixel_values"],
"input_ids": torch.stack([ex[1] for ex in batch]),
"attention_mask": torch.stack([ex[2] for ex in batch]),
"labels": torch.stack([ex[3] for ex in batch])
}
def get_dataloader(split: str, config: configs.ExperimentConfig):
dataset = ImageCaptionDataset(
input_dir=os.path.join(config.local_data_path, split),
shuffle=True if split == "train" else False,
)
# Image transformation function
processor = CLIPProcessor.from_pretrained(
config.training_config.pre_trained_backbone
)
return StreamingDataLoader(
dataset,
batch_size=config.dataset_config.batch_size,
shuffle=True if split == "train" else False,
num_workers=config.dataset_config.num_workers,
collate_fn=partial(
collate_fn,
processor=processor
),
drop_last=True,
pin_memory=config.dataset_config.pin_memory,
)
Am I setting up the dataloader properly here? I checked and LitGPT uses a torch DataLoader instead of a StreamingDataLoader (a rough sketch of that variant is below).
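For comparison, the plain torch.utils.data.DataLoader variant (the LitGPT-style setup) would look roughly like this, reusing the same dataset and collate_fn; shuffling is left to the StreamingDataset itself, and note that a plain DataLoader does not expose state_dict()/load_state_dict() for resuming:

from torch.utils.data import DataLoader

def get_plain_dataloader(split: str, config: configs.ExperimentConfig):
    # Same dataset and collate as above; only the loader class changes
    dataset = ImageCaptionDataset(
        input_dir=os.path.join(config.local_data_path, split),
        shuffle=(split == "train"),  # shuffling handled by the StreamingDataset
    )
    processor = CLIPProcessor.from_pretrained(
        config.training_config.pre_trained_backbone
    )
    return DataLoader(
        dataset,
        batch_size=config.dataset_config.batch_size,
        num_workers=config.dataset_config.num_workers,
        collate_fn=partial(collate_fn, processor=processor),
        drop_last=True,
        pin_memory=config.dataset_config.pin_memory,
    )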
Here I show what I managed to monitor of how CPU and RAM usage look over an entire epoch.
You can see how, instead of moving on to load the samples of the test set, it hangs ...
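To see where each rank is actually stuck when this happens, I am adding the stdlib faulthandler at the top of the training script (nothing Fabric- or LitData-specific; the timeout value is an arbitrary choice on my side):

import faulthandler
import signal

# Periodically dump the Python stack of every thread (and on demand via `kill -USR1 <pid>`)
# so I can see where each rank is blocked when the hang happens
faulthandler.dump_traceback_later(timeout=1800, repeat=True)
faulthandler.register(signal.SIGUSR1, all_threads=True)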
Training using Fabric
I include below some parts of my training script, which basically follows the open-clip implementation but uses Lightning Fabric. Maybe I am doing something wrong when the epoch is finishing? I noticed the state is not saved at the end of the epoch on the last iteration (only in the middle of training, on checkpoint_step):
def set_fabric(logger: CSVLogger, config: configs.ExperimentConfig):
"""Set fabric using VertexAI cluster environment
"""
strategy = DDPStrategy(
static_graph=True,
cluster_environment=utils.infrastructure.VertexAICluster()
)
fabric = Fabric(
accelerator=config.training_config.accelerator,
strategy=strategy,
devices=config.training_config.num_gpus,
num_nodes=config.training_config.num_nodes,
precision=config.training_config.precision,
loggers=logger,
)
fabric.launch()
# utils.data_streaming.set_streaming_env_vars(fabric) <- necessary for mosaicml-streaming
return fabric
def save_model_weights(model: nn.Module, fabric: Fabric, experiment_config: configs.ExperimentConfig):
"""Save only model weights to remote storage,
these are the ones that need to be called at inference time
"""
if fabric.global_rank == 0:
model.backbone.save_pretrained(experiment_config.model_path)
def save_state(
model: nn.Module,
state: dict,
fabric: Fabric,
experiment_config: configs.ExperimentConfig
):
"""
Save state of the training session every checkpoint_step or at the end of the epoch
"""
if state["current_step"] != 0 and (
state["current_step"] % experiment_config.checkpoint_step == 0
or state["iteration"] == state["num_batches_per_epoch"] - 1
):
fabric.save(
os.path.join(
experiment_config.checkpoint_path,
f"{state['current_epoch']:04d}-{state['current_step']:04d}-state.ckpt"
),
state
)
save_model_weights(model, fabric, experiment_config)
def train_epoch(
train_loader: torch.utils.data.DataLoader,
model: nn.Module,
state: dict,
zero_shot_attributes: dict,
optimizer: torch.optim.Optimizer,
criterion: nn.Module,
scheduler: torch.optim.lr_scheduler,
fabric: Fabric,
metrics_dict: dict,
experiment_config: configs.ExperimentConfig
):
"""Train the model for one epoch (an entire training cycle)
"""
model.train()
accum_samples, accum_features = [], {}
num_accumulated = 0
for idx, batch in enumerate(train_loader, start=state["start_iteration"]):
time_start = time.perf_counter()
batch = fabric.to_device(batch)
optimizer.zero_grad()
with torch.no_grad():
output = model(batch)
output.pop("logit_scale", None)
for key, value in output.items():
if key not in accum_features:
accum_features[key] = [value]
else:
accum_features[key].append(value)
accum_samples.append(batch)
num_accumulated += 1
state["iteration"] += 1
if (idx + 1) % experiment_config.training_config.accum_freq > 0:
continue
# compute embeddings on zero shot attributes
with torch.no_grad():
zero_shot_embeddings = model(zero_shot_attributes=zero_shot_attributes)
# compute loss aggregating the embeddings accumulated
optimizer.zero_grad()
for j in range(num_accumulated):
batch = accum_samples[j]
output = model(batch)
# take the logit scale from the current forward pass
inputs = {"logit_scale": output.pop("logit_scale")}
for name, features in accum_features.items():
accumulated = accum_features[name]
inputs[name] = torch.cat(accumulated[:j] + [output[name]] + accumulated[j + 1:])
loss = criterion(**inputs, fabric=fabric)
del inputs
fabric.backward(loss)
optimizer.step()
scheduler.step()
# update metrics
if check_if_is_log_step(state, experiment_config):
update_metrics(
"train", loss, accum_features, accum_samples, zero_shot_embeddings, metrics_dict, fabric
)
# reset accumulation
accum_samples, accum_features = [], {}
num_accumulated = 0
with torch.no_grad():
model.backbone.logit_scale.clamp_(0, math.log(100))
timing = time.perf_counter() - time_start
# log results
if check_if_is_log_step(state, experiment_config):
log_batch_metrics(
timing, metrics_dict, fabric, state, scheduler.get_last_lr(), experiment_config
)
# save state if needed
save_state(model, state, fabric, experiment_config)
state["current_step"] += 1
fabric.print(f"++++ Epoch {state['current_epoch']} completed ++++")
@torch.no_grad()
def validate_epoch(
test_loader: torch.utils.data.DataLoader,
model: nn.Module,
state: dict,
zero_shot_attributes: dict,
criterion: nn.Module,
metrics_dict: dict,
fabric: Fabric,
experiment_config: configs.ExperimentConfig
):
fabric.barrier()
fabric.print(f"++++ Validating epoch {state['current_epoch']} ++++")
model.eval()
for idx, batch in enumerate(test_loader):
batch = fabric.to_device(batch)
output = model(batch)
zero_shot_embeddings = model(zero_shot_attributes=zero_shot_attributes)
loss = criterion(**output, fabric=fabric)
update_metrics("test", loss, output, batch, zero_shot_embeddings, metrics_dict, fabric)
if idx % (experiment_config.verbose_step * 10) == 0 or idx == len(test_loader) - 1:
fabric.print(
f"+++ Epoch: {state['current_epoch']:04d} "
f"| Test Step: {idx}/{len(test_loader)}"
f"| Test Loss: {metrics_dict['test_loss'].compute().item()}"
f"| Test Label Hit Rate {metrics_dict['test_label_hit_rate'].compute().item()}"
f"| Test Precision@k {metrics_dict['test_precision_at_k'].compute().item()}"
f"| Test Cosine Similarity {metrics_dict['test_cosine_similarity'].compute().item()}"
f" +++"
)
def fit(
train_loader: torch.utils.data.DataLoader,
test_loader: torch.utils.data.DataLoader,
model: nn.Module,
state: dict,
zero_shot_attributes: dict,
optimizer: torch.optim.Optimizer,
criterion: nn.Module,
scheduler: torch.optim.lr_scheduler,
fabric: Fabric,
metrics_dict: dict,
experiment_config: configs.ExperimentConfig
):
"""Fit the model"""
fabric.print(f"+++ Number of raw training batches: {len(train_loader)}")
fabric.print(f"+++ Number of raw test batches: {len(test_loader)}")
for epoch in range(state["current_epoch"], experiment_config.dataset_config.num_epochs):
train_epoch(
train_loader,
model,
state,
zero_shot_attributes,
optimizer,
criterion,
scheduler,
fabric,
metrics_dict,
experiment_config
)
validate_epoch(
test_loader,
model,
state,
zero_shot_attributes,
criterion,
metrics_dict,
fabric,
experiment_config
)
log_epoch_metrics(metrics_dict, fabric)
state["start_iteration"] = 0
state["current_epoch"] += 1
# save final version of the model & metadata
save_model_weights(model, fabric, experiment_config)
save_metadata(fabric, experiment_config)
fabric.logger.finalize("success")
def get_current_epoch_iteration(
train_loader: torch.utils.data.DataLoader,
fabric: Fabric,
experiment_config: configs.ExperimentConfig
) -> int:
"""
Get the in-epoch iteration to resume from (used as the start value for enumerate)
"""
state_dict = train_loader.state_dict()
num_train_samples = experiment_config.dataset_config.num_train_samples
if "num_samples_yielded" in state_dict:
# litdata format
samples_seen = state_dict["num_samples_yielded"]
current_epoch = state_dict["current_epoch"]
else:
# mosaic-ml format
samples_seen = state_dict["sample_in_epoch"]
current_epoch = state_dict["epoch"]
iteration = (samples_seen * len(train_loader)) // num_train_samples
fabric.print(f"++++ Resuming epoch {current_epoch} from iteration: {iteration} ++++")
return iteration
def get_initial_state(model, optimizer, train_loader, scheduler):
"""Get the state of the session"""
return {
"current_epoch": 0,
"current_step": 0, # model update
"iteration": 0, # dataset-wise
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"train_loader": train_loader.state_dict(),
"scheduler": scheduler.state_dict()
}
def get_state(
model: nn.Module,
optimizer: torch.optim.Optimizer,
train_loader: torch.utils.data.DataLoader,
scheduler: torch.optim.lr_scheduler,
fabric: Fabric,
experiment_config: configs.ExperimentConfig
):
"""Loads a checkpoint from a given file into state from remote storage
"""
current_checkpoints = sorted(utils.io.glob_gcs(
os.path.join(experiment_config.checkpoint_path, "*-state.ckpt")
), reverse=True)
if current_checkpoints:
fabric.print(
f"++++ Resuming training using state from: {os.path.basename(current_checkpoints[0])} ++++"
)
state = fabric.load(current_checkpoints[0].replace("gs://", "/gcs/"))
model.load_state_dict(state["model"])
optimizer.load_state_dict(state["optimizer"])
train_loader.load_state_dict(state["train_loader"])
scheduler.load_state_dict(state["scheduler"])
# modify start iteration state in case of resuming training
state["start_iteration"] = get_current_epoch_iteration(train_loader, fabric, experiment_config)
else:
fabric.print(f"++++ Starting Training from epoch: 0 ++++")
state = get_initial_state(model, optimizer, train_loader, scheduler)
state["start_iteration"] = 0
# extra metadata
state["num_batches_per_epoch"] = len(train_loader)
state["num_steps_per_epoch"] = len(train_loader) // experiment_config.training_config.accum_freq
return state
def launch_training(experiment_config: configs.ExperimentConfig):
"""Set up the training environment, launcher and data loaders
"""
logger = CSVLogger(
root_dir=experiment_config.artifacts_path,
name="logs", version="clip"
)
fabric = set_fabric(logger, experiment_config)
fabric.print("++++ Setting up model and optimizer ++++")
model = models.get_model(experiment_config)
optimizer = optimization.get_optimizer(model, experiment_config)
model, optimizer = fabric.setup(model, optimizer)
lr_scheduler = optimization.get_learning_rate(optimizer, experiment_config)
fabric.print("++++ Setting up dataloaders ++++")
train_loader = data.get_dataloader("train", experiment_config)
test_loader = data.get_dataloader("test", experiment_config)
fabric.barrier()
# get and resume state if available
state = get_state(
model, optimizer, train_loader, lr_scheduler, fabric, experiment_config
)
# get losses
criterion = losses.get_loss(experiment_config, fabric)
models.print_trainable_parameters(model, fabric, experiment_config.training_config)
# get zero shot tokens
fabric.print("++++ Loading zero shot attribute tokens ++++")
zero_shot_attributes = fabric.to_device(
torch.load(experiment_config.zero_shot_tokens_fn)
)
# define metrics to track during training/validation
metrics_dict = metrics.get_metrics(fabric.device)
# sync before starting training
fabric.barrier()
fit(
train_loader,
test_loader,
model,
state,
zero_shot_attributes,
optimizer,
criterion,
lr_scheduler,
fabric,
metrics_dict,
experiment_config
)
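For completeness, the entry point is essentially just this (how the config is actually built is project-specific, so the construction below is a placeholder):

if __name__ == "__main__":
    # Placeholder: in practice the experiment config is parsed from CLI/YAML
    launch_training(configs.ExperimentConfig())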
Environment
- GPU models and configuration: 4 n1-standard-32 nodes on GCP, each with 4 NVIDIA L4 GPUs
- Any other relevant information:
I use the PyTorch container image from GCP:
europe-docker.pkg.dev/vertex-ai/training/pytorch-gpu.1-13.py310:latest
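If it helps, I can also rerun with the usual PyTorch/NCCL debugging variables enabled before fabric.launch() (a generic suggestion on my side, not something Fabric requires):

import os

# Standard distributed-debugging knobs; they make collective mismatches and NCCL stalls visible in the logs
os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "DETAIL")
os.environ.setdefault("NCCL_DEBUG", "INFO")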
Expected behavior
The training dataloader finishes the epoch and the rest of the code continues its execution.
Hey @miguelalba96. Thanks for reporting this issue.
Would you mind printing the length of each dataset and dataloader on each rank? Usually it hangs when one rank has more data than the others. It shouldn't happen, but I want to exclude this eventuality.
Do you think you could share a tiny reproducible example with dummy data for me to debug?
Best, T.C
When printing the ranks per node and, on each, len(dataset) and len(dataloader), I get a homogeneous number of samples on each:
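The check itself is just this snippet, placed right after the loaders are created in launch_training (printed with the built-in print so every rank reports, not only rank 0):

print(
    f"node_rank={fabric.node_rank} local_rank={fabric.local_rank} global_rank={fabric.global_rank} | "
    f"len(train_dataset)={len(train_loader.dataset)} len(train_loader)={len(train_loader)} | "
    f"len(test_dataset)={len(test_loader.dataset)} len(test_loader)={len(test_loader)}",
    flush=True,
)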
I am not sure how to reproduce this problem; I will check. I also noticed that when I load the state to resume training using the get_state function I wrote above, the dataloader doesn't seem to resume properly and iterates through the data all over again until it hangs 🤔:
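To double-check the resume behaviour in isolation, I am planning to run something like this outside of Fabric, using the same state_dict()/load_state_dict() calls as in get_state (experiment_config here is a placeholder for my loaded config):

loader = get_dataloader("train", experiment_config)

# consume a few batches, then capture the loader state as the checkpoint would
saved_state = None
for i, _batch in enumerate(loader):
    if i == 10:
        saved_state = loader.state_dict()
        break
assert saved_state is not None
print("saved:", saved_state.get("num_samples_yielded"), saved_state.get("current_epoch"))

# a fresh loader restored from that state should continue after the samples already
# yielded, instead of iterating over the data from the beginning again
resumed_loader = get_dataloader("train", experiment_config)
resumed_loader.load_state_dict(saved_state)
for i, _batch in enumerate(resumed_loader):
    break  # just confirm it starts from the expected position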
Hey @miguelalba96, any chance you could create a reproducible Studio on https://lightning.ai/ that I can duplicate to investigate what's happening? Otherwise, it is hard for me to help you.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.