Training speed is very slow

Status: Open • liuweie opened this issue 1 year ago • 6 comments

Hi, I use grad_cache to train my model, but it seems very slow. Is this normal? Does using GradCache generally affect training speed?

liuweie, Jun 25 '24 11:06

It'd be hard to diagnose this from a qualitative description. Maybe you can share some details of your setup, the observed and reference throughput/latency, etc.

luyug, Jun 25 '24 17:06

Thanks. I am using the Hugging Face Trainer to train a Qwen7B model. Here are my setup and the corresponding code:

① A compute_loss method that overrides the original Trainer compute_loss:

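from typing import Any, Dict, Tuple, Union

import torch
from torch import nn
from transformers import Trainer
from grad_cache import GradCache


class GradCacheTrainer(Trainer):  # hypothetical name; the original subclass is not shown
    def compute_loss(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        # inputs is unpacked as a (features, labels) pair: features[0] holds the
        # query batch and features[1] the positive-document batch.
        features, labels = inputs
        query_input = features[0]
        pos_doc_input = features[1]

        loss_fn = DistributedContrastiveLoss(temperature=20, n_hard_negatives=0)
        # A new GradCache is built on every call. The same encoder encodes both
        # sides, and it is self.model rather than the (possibly wrapped) `model`
        # argument that Trainer passes in.
        gc = GradCache(
            models=[self.model, self.model],
            chunk_sizes=1,
            loss_fn=loss_fn,
            get_rep_fn=None,  # the model output itself is used as the representation
            fp16=False,
        )

        # GradCache runs its own forward and backward passes and returns a detached
        # loss; requires_grad_() only lets Trainer call backward() on it.
        detached_loss = gc(query_input, pos_doc_input).requires_grad_()
        return detached_loss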

As you can see, I set chunk_sizes=1; I also tried chunk sizes of 4, 16, and 64. The batch size in the Trainer settings is 256, and the hardware is two A800 (80 GB) GPUs.
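For context on what gc(...) does with those chunk sizes: GradCache first encodes all inputs without building a graph, computes the loss and its gradient with respect to the representations (the "cache"), and then processes the inputs chunk by chunk, backpropagating the cached gradients through the encoder. In exact arithmetic, the chunk size therefore changes speed and memory but not the gradient. Below is a toy, pure-Python sketch of that idea, not the grad_cache implementation: the encoder is a hypothetical one-parameter model rep = w * x shared by queries and documents, the loss is an in-batch cross-entropy with a logit scale, and the chunked gradient is compared with a finite-difference gradient of the full-batch loss.

```python
import math
from typing import List, Tuple


def loss_and_rep_grads(q: List[float], d: List[float], scale: float) -> Tuple[float, List[float], List[float]]:
    """Mean in-batch cross-entropy on 1-d reps, logits z[i][j] = scale * q[i] * d[j],
    with document i as the positive for query i. Returns (loss, dL/dq, dL/dd)."""
    n = len(q)
    loss, gq, gd = 0.0, [0.0] * n, [0.0] * n
    for i in range(n):
        z = [scale * q[i] * d[j] for j in range(n)]
        m = max(z)
        exps = [math.exp(v - m) for v in z]
        total = sum(exps)
        p = [e / total for e in exps]
        loss += -math.log(p[i]) / n
        for j in range(n):
            dz = (p[j] - (1.0 if i == j else 0.0)) / n  # dL/dz[i][j]
            gq[i] += dz * scale * d[j]
            gd[j] += dz * scale * q[i]
    return loss, gq, gd


def full_batch_loss(w: float, a: List[float], b: List[float], scale: float) -> float:
    # Hypothetical toy encoder rep = w * x, shared by queries (a) and documents (b).
    return loss_and_rep_grads([w * x for x in a], [w * x for x in b], scale)[0]


def grad_cache_step(w: float, a: List[float], b: List[float], scale: float, chunk_size: int) -> Tuple[float, float]:
    # Stage 1: encode everything "without grad", keeping only the representations.
    q = [w * x for x in a]
    d = [w * x for x in b]
    # Stage 2: loss and its gradient w.r.t. the representations (the cache).
    loss, gq, gd = loss_and_rep_grads(q, d, scale)
    # Stage 3: walk the inputs chunk by chunk and push the cached gradients through
    # the encoder. For rep = w * x, d(rep)/dw = x; a real model would re-run the
    # chunk's forward pass here and call backward with the cached gradients.
    grad_w = 0.0
    for inputs, cache in ((a, gq), (b, gd)):
        for start in range(0, len(inputs), chunk_size):
            end = min(start + chunk_size, len(inputs))
            grad_w += sum(cache[i] * inputs[i] for i in range(start, end))
    # The loss is a plain number here, much like the detached loss gc(...) returns.
    return loss, grad_w


a = [0.3, -1.2, 0.8, 2.0]
b = [0.5, -1.0, 1.1, 1.7]
w, scale, eps = 0.7, 2.0, 1e-6
numeric = (full_batch_loss(w + eps, a, b, scale) - full_batch_loss(w - eps, a, b, scale)) / (2 * eps)
for chunk_size in (1, 2, 4):
    _, grad_w = grad_cache_step(w, a, b, scale, chunk_size)
    print(chunk_size, round(grad_w, 6), round(numeric, 6))
```

Every chunk size should print the same gradient as the finite-difference estimate; that equivalence is what lets gradient caching trade extra compute for memory.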

② The DistributedContrastiveLoss class, which is similar to DistributedContrastiveLoss in loss.py of the grad_cache package; I only added a temperature parameter to scale the scores:

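import torch
import torch.distributed as dist
from torch import Tensor


# SimpleContrastiveLoss (shown below) must be defined before this class.
class DistributedContrastiveLoss(SimpleContrastiveLoss):
    def __init__(self, temperature, n_hard_negatives: int = 0):
        assert dist.is_initialized(), "Distributed training has not been properly initialized."

        super().__init__(temperature=temperature, n_hard_negatives=n_hard_negatives)
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()

    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        # Score the queries from all ranks against the documents from all ranks.
        dist_x = self.gather_tensor(x)
        dist_y = self.gather_tensor(y)

        return super().__call__(dist_x, dist_y, **kwargs)

    def gather_tensor(self, t):
        gathered = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(gathered, t)
        # all_gather outputs carry no autograd history; putting the local tensor
        # back keeps gradients flowing to this rank's representations.
        gathered[self.rank] = t

        return torch.cat(gathered, dim=0)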
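Because every rank concatenates the gathered rows in rank order, the default target arange(0, N * target_per_qry, target_per_qry) still points each query at its own positive after the gather. This holds as long as every rank uses the same per-rank batch size (all_gather requires equal shapes) and stores each query's documents contiguously with the positive first. A list-based simulation of that bookkeeping, under exactly those assumptions (strings stand in for embedding rows; the helper names are made up for illustration). It cannot show the autograd side, which is what the gathered[self.rank] = t line handles.

```python
from typing import List


def simulate_all_gather(per_rank: List[List[str]]) -> List[str]:
    # Stand-in for dist.all_gather followed by torch.cat(dim=0): rank r's rows fill
    # slot r, so the global order is rank 0's rows, then rank 1's, and so on.
    gathered: List[str] = []
    for rows in per_rank:
        gathered.extend(rows)
    return gathered


def default_targets(n_queries: int, target_per_qry: int) -> List[int]:
    # Same values as torch.arange(0, n_queries * target_per_qry, target_per_qry).
    return list(range(0, n_queries * target_per_qry, target_per_qry))


# Hypothetical layout: 2 ranks, 2 queries per rank, 1 hard negative per query,
# with each query's documents stored as [positive, negative].
world_size, per_rank, k = 2, 2, 2
queries = [[f"q{r}.{i}" for i in range(per_rank)] for r in range(world_size)]
docs = [[f"{kind}{r}.{i}" for i in range(per_rank) for kind in ("pos", "neg")]
        for r in range(world_size)]

all_q = simulate_all_gather(queries)
all_d = simulate_all_gather(docs)
for n, t in enumerate(default_targets(len(all_q), k)):
    print(all_q[n], "->", all_d[t])
```

Each gathered query lines up with its own positive (q0.0 with pos0.0, q1.1 with pos1.1, and so on); with n_hard_negatives=0, as in the issue, k is 1 and the targets are simply 0..N-1.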

where SimpleContrastiveLoss looks like this:

import torch
import torch.nn.functional as F
from torch import Tensor


class SimpleContrastiveLoss:
    def __init__(self, temperature, n_hard_negatives: int = 0):
        self.target_per_qry = n_hard_negatives + 1
        # Despite the name, this value multiplies the logits (a scale), so
        # temperature=20 is equivalent to dividing by a temperature of 0.05.
        self.temperature = temperature

    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean'):
        if target is None:
            # Query i's positive sits at index i * target_per_qry in y.
            assert x.size(0) * self.target_per_qry == y.size(0)
            target = torch.arange(0, x.size(0) * self.target_per_qry, self.target_per_qry, device=x.device)

        logits = torch.matmul(x, y.transpose(0, 1))
        logits = logits * self.temperature
        return F.cross_entropy(logits, target, reduction=reduction)
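To sanity-check loss values outside of PyTorch, here is a dependency-free mirror of SimpleContrastiveLoss for small lists of vectors. It is only an illustration (the function name and list-based inputs are assumptions, not part of grad_cache), and it makes explicit that temperature in this code multiplies the logits, which matters when comparing against setups that divide by a temperature such as 0.05.

```python
import math
from typing import List, Optional


def contrastive_loss(x: List[List[float]], y: List[List[float]], scale: float,
                     n_hard_negatives: int = 0, target: Optional[List[int]] = None) -> float:
    """Pure-Python mirror of SimpleContrastiveLoss above with reduction='mean'.

    `scale` multiplies the dot-product logits, just like `temperature` in the
    issue's code, so scale=20 matches dividing by a temperature of 0.05.
    """
    k = n_hard_negatives + 1
    if target is None:
        assert len(x) * k == len(y)
        target = list(range(0, len(x) * k, k))
    total = 0.0
    for xi, t in zip(x, target):
        logits = [scale * sum(p * q for p, q in zip(xi, yj)) for yj in y]
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
        total += log_sum_exp - logits[t]  # cross-entropy for this query
    return total / len(x)


# Two orthonormal pairs: positives score `scale`, negatives score 0.
x = [[1.0, 0.0], [0.0, 1.0]]
y = [[1.0, 0.0], [0.0, 1.0]]
print(contrastive_loss(x, y, scale=20.0))  # equals log(1 + exp(-20)), i.e. tiny
print(contrastive_loss(x, y, scale=1.0))   # equals log(1 + exp(-1))
```

The same embeddings give very different loss values depending on the scale, so losses are only comparable across runs that use the same scale and the same number of candidates.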

In the end, the code runs successfully, but the gradient updates are extremely slow.

liuweie, Jun 26 '24 03:06
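On the slowness itself, a rough count of encoder calls per step helps explain why chunk_sizes=1 hurts. This sketch assumes the two-stage procedure GradCache is built on (one no-grad forward per chunk to build the cache, then one forward plus backward per chunk to apply it) and assumes 256 is the per-device batch size, which the issue does not state; the helper name is hypothetical.

```python
import math


def gradcache_call_counts(batch_size: int, chunk_size: int, n_inputs: int = 2) -> dict:
    """Rough count of encoder calls in one GradCache step (a sketch, not a profiler).

    Assumes each of the n_inputs batches (here queries and positive documents) is
    split into ceil(batch_size / chunk_size) chunks, and that every chunk gets one
    no-grad forward (to build the cache) plus one forward and backward (to apply it).
    """
    chunks = math.ceil(batch_size / chunk_size)
    return {
        "chunks_per_input": chunks,
        "forward_calls": 2 * chunks * n_inputs,
        "backward_calls": chunks * n_inputs,
    }


# Assumption: 256 is the per-device batch size.
for chunk_size in (1, 4, 16, 64, 256):
    print(chunk_size, gradcache_call_counts(256, chunk_size))
```

With chunk size 1, each example goes through the 7B model on its own, so a step becomes 1,024 small forward calls and 512 backward calls instead of a handful of large ones. The extra no-grad forward is inherent to gradient caching; most of the remaining cost is per-call overhead, so the largest chunk that fits in memory is usually the fastest setting. Separately, grad_cache's cache_step accepts no_sync_except_last=True for DDP-wrapped models so that gradients are synchronized only on the last chunk; whether that applies here depends on which model object is handed to GradCache, so it is worth checking against the installed version.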

Can you share the observed runtime and a reference runtime?

Meanwhile, note that the Hugging Face Trainer can enable features such as DeepSpeed ZeRO, which came after the GradCache release and therefore may not be smoothly supported.

luyug, Jun 26 '24 16:06
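To produce the observed and reference runtimes asked for above, a minimal stdlib timing helper is enough. It is a sketch with a hypothetical name: step stands for whatever you want to time (for example one GradCache step versus one plain step at a batch size that fits without caching), and for GPU work you would pass a synchronization function such as torch.cuda.synchronize as sync, because CUDA kernels run asynchronously and the timer would otherwise stop too early.

```python
import time
from typing import Callable, Optional


def mean_step_seconds(step: Callable[[], None], n_steps: int = 10, n_warmup: int = 2,
                      sync: Optional[Callable[[], None]] = None) -> float:
    """Mean wall-clock seconds per call of `step` (a hypothetical helper, not part of
    grad_cache or transformers). `sync`, if given, runs right before each timer
    read; for GPU work pass something like torch.cuda.synchronize."""
    for _ in range(n_warmup):  # exclude one-off costs such as allocation or compilation
        step()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(n_steps):
        step()
    if sync is not None:
        sync()
    return (time.perf_counter() - start) / n_steps


# Example with a stand-in step that just sleeps for 10 ms.
seconds = mean_step_seconds(lambda: time.sleep(0.01), n_steps=5, n_warmup=1)
print(f"{seconds * 1000:.1f} ms/step, {256 / seconds:.0f} samples/s at batch 256")
```

Reporting both ms/step and samples/s makes runs with different batch or chunk sizes comparable.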

@luyug I think I have figured this problem out, thanks. However, during my experiments I found that the loss has great difficulty converging. Here is my log:

{'loss': 4.8621, 'grad_norm': 3.703125, 'learning_rate': 0.0001978902953586498, 'epoch': 0.08}
{'loss': 3.3299, 'grad_norm': 1.140625, 'learning_rate': 0.00019261603375527427, 'epoch': 0.16}
{'loss': 3.0912, 'grad_norm': 3.578125, 'learning_rate': 0.00018734177215189873, 'epoch': 0.23}
{'loss': 2.6461, 'grad_norm': 1.0234375, 'learning_rate': 0.00018206751054852322, 'epoch': 0.31}
{'loss': 1.8239, 'grad_norm': 0.7578125, 'learning_rate': 0.00017679324894514769, 'epoch': 0.39}
{'loss': 2.4623, 'grad_norm': 4.4375, 'learning_rate': 0.00017151898734177218, 'epoch': 0.47}
{'loss': 2.1719, 'grad_norm': 1.1640625, 'learning_rate': 0.00016624472573839661, 'epoch': 0.55}
{'loss': 2.6063, 'grad_norm': 16.125, 'learning_rate': 0.0001609704641350211, 'epoch': 0.62}
{'loss': 2.2289, 'grad_norm': 6.40625, 'learning_rate': 0.0001556962025316456, 'epoch': 0.7}
{'loss': 2.1505, 'grad_norm': 1.5078125, 'learning_rate': 0.00015042194092827003, 'epoch': 0.78}
{'loss': 2.2342, 'grad_norm': 3.375, 'learning_rate': 0.00014514767932489453, 'epoch': 0.86}
{'loss': 1.7903, 'grad_norm': 2.125, 'learning_rate': 0.000139873417721519, 'epoch': 0.93}
{'loss': 2.5553, 'grad_norm': 2.21875, 'learning_rate': 0.00013459915611814345, 'epoch': 1.01}
{'loss': 1.8375, 'grad_norm': 0.95703125, 'learning_rate': 0.00012932489451476795, 'epoch': 1.09}
{'loss': 2.379, 'grad_norm': 0.71875, 'learning_rate': 0.0001240506329113924, 'epoch': 1.17}
{'loss': 2.5603, 'grad_norm': 3.203125, 'learning_rate': 0.00011877637130801689, 'epoch': 1.25}
....
....
...
{'loss': 1.9158, 'grad_norm': 2.84375, 'learning_rate': 3.966244725738397e-05, 'epoch': 2.41}
{'loss': 2.2063, 'grad_norm': 3.625, 'learning_rate': 3.438818565400844e-05, 'epoch': 2.49}
{'loss': 2.1187, 'grad_norm': 1.46875, 'learning_rate': 2.9113924050632914e-05, 'epoch': 2.57}
{'loss': 2.055, 'grad_norm': 0.94140625, 'learning_rate': 2.3839662447257385e-05, 'epoch': 2.65}

As you can see, the loss stays around 2, whereas without GradCache the loss converges to 0.2.

liuweie, Jul 01 '24 03:07
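One hedged way to read these numbers: the two runs are only directly comparable if each query is scored against the same number of candidates. With DistributedContrastiveLoss and no hard negatives, that number is the global batch (per-device batch times world size). If the run without GradCache used a smaller batch (the issue does not say), each query faced fewer negatives and a lower chance-level loss. The sketch below, with hypothetical helper names and an assumed global batch of 2 × 256, only shows how to put losses on a common scale.

```python
import math


def uniform_guess_loss(n_candidates: int) -> float:
    # Cross-entropy when all candidates get the same score: -log(1 / N) = log(N).
    return math.log(n_candidates)


def implied_positive_prob(mean_loss: float) -> float:
    # exp(-mean cross-entropy) is the geometric mean of the probability
    # the model assigns to each query's positive.
    return math.exp(-mean_loss)


# Assumption: 2 GPUs x 256 per device, gathered across ranks, no hard negatives.
print(uniform_guess_loss(2 * 256))  # chance level with 512 candidates per query
print(implied_positive_prob(2.0), implied_positive_prob(0.2))
```

Under that assumption, a loss of 2.0 means a geometric-mean probability of about 0.14 on the positive, versus 1/512 ≈ 0.002 for a uniform guess and about 0.82 at a loss of 0.2. The GradCache run is therefore well below chance level, but the gap to 0.2 can only be judged once both runs use the same candidate count.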

Hi, I am having a similar issue: the loss does not converge after switching to GradCache. Did you solve it?

lfb-1, Sep 25 '24 03:09

Hi @liuweie, I am facing the same issue: training is very slow. I also observed that only a few GB of GPU memory were in use, even after I increased the batch size. How did you handle this problem? Thank you so much for your help!

liyongkang123, Nov 25 '24 23:11
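Low GPU memory with a large batch is what gradient caching is designed to produce: only one chunk's computation graph is alive at a time, so peak activation memory follows chunk_sizes rather than the batch size. What does grow with the batch is the cache of representations and their gradients, which is small next to a 7B model's weights. A back-of-the-envelope sketch, assuming a hidden size of 4096 and fp32 representations (both assumptions; check your model config), with a hypothetical helper name:

```python
def rep_cache_bytes(batch_size: int, hidden_dim: int, bytes_per_elem: int = 4) -> int:
    # Per input side (queries or documents): the cached representations plus the
    # cached gradients of the loss w.r.t. them, both shaped [batch_size, hidden_dim].
    return 2 * batch_size * hidden_dim * bytes_per_elem


# Assumptions: hidden_dim=4096 and fp32 representations.
for batch_size in (256, 1024, 4096):
    print(batch_size, rep_cache_bytes(batch_size, 4096) / 2**20, "MiB")
```

So with a small chunk size, raising the batch size mostly adds more sequential chunks (more time) rather than more memory; increasing the chunk size is the knob that trades memory for speed.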