
Using Gradient Accumulation

bexxnaz opened this issue · 3 comments

Hello, and thanks for your great work. I need to use gradient accumulation due to RAM constraints, and the training loop iterates over two modalities. I am concerned about the implications of using gradient accumulation in this scenario. Is it possible, and recommended, to use gradient accumulation with multiple modalities in an iterator? My current loop looks like this:

    with torch.cuda.amp.autocast(enabled=config.fp16):
        loss_dict = model(image, text)
        loss = sum(loss_dict.values()) / config.accumulate_grad_batches
    scaler.scale(loss).backward()
    accumulated_batches += 1
    if accumulated_batches % config.accumulate_grad_batches == 0:
        if config.optimizer.max_grad_norm > 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()  # Reset gradients only after the optimizer step
        scheduler.step()       # Step the scheduler once per effective (accumulated) batch
        accumulated_batches = 0

bexxnaz avatar May 14 '24 10:05 bexxnaz

Good question! Can you try using a smaller batch? In my previous experience, the results are similar.

Besides, if you want to use gradient accumulation, I think keeping multiple modalities in one iterator is a good baseline, but some papers argue that iterating over a single modality at a time is better. If you want to implement that, a simple strategy is to split the input data manually, roughly as sketched below.
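For example, here is a rough sketch of that manual split, assuming one loader per modality and reusing the names (`model`, `scaler`, `config`, `optimizer`, `scheduler`) from your snippet above. It is only illustrative, not our training code:

```python
# Illustrative only: one single-modality sub-step per forward/backward,
# gradients still accumulated over config.accumulate_grad_batches iterations.
accumulated_batches = 0
for (image, image_text), (video, video_text) in zip(image_loader, video_loader):
    for media, text in ((image, image_text), (video, video_text)):
        with torch.cuda.amp.autocast(enabled=config.fp16):
            loss_dict = model(media, text)
            # Two single-modality sub-steps per iteration, so divide by both factors
            loss = sum(loss_dict.values()) / (2 * config.accumulate_grad_batches)
        scaler.scale(loss).backward()

    accumulated_batches += 1
    if accumulated_batches % config.accumulate_grad_batches == 0:
        if config.optimizer.max_grad_norm > 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        scheduler.step()
        accumulated_batches = 0
```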

Andy1621 avatar May 20 '24 03:05 Andy1621

Thank you for your response. I have another question regarding the extra_num_query_tokens. Specifically, I'm interested in understanding if you've tested the scenario where this parameter is set to 0. How does this compression of visual tokens affect performance?

bexxnaz avatar May 27 '24 08:05 bexxnaz

There is an ablation in our paper, and 0 extra query tokens leads to poorer performance on MVBench (see the ablation table screenshot in the original comment).
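For reference, the compression works along the lines of a Q-Former: a small, fixed set of learned query tokens cross-attends to all visual tokens, so the output length is the number of queries rather than the number of visual tokens. The sketch below is a generic illustration with hypothetical names, not the actual module in this repo:

```python
import torch
import torch.nn as nn

class QueryCompressor(nn.Module):
    """Generic illustration: learned query tokens cross-attend to visual tokens,
    compressing them to num_query outputs; more queries keep more visual detail."""
    def __init__(self, dim=768, num_query=32, num_heads=12):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, num_query, dim) * 0.02)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, visual_tokens):  # visual_tokens: (B, T, dim)
        q = self.query.expand(visual_tokens.size(0), -1, -1)
        out, _ = self.attn(q, visual_tokens, visual_tokens)
        return out  # (B, num_query, dim), independent of T
```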

Andy1621 avatar May 27 '24 08:05 Andy1621

Hi, we will close this issue.

Feel free to contact us if you have other questions.

yinanhe avatar Oct 14 '24 03:10 yinanhe