
SequentialLR cannot be used with ReduceLROnPlateau due to .step() not allowing for optional arguments

Open marcm-ml opened this issue 4 years ago • 9 comments

🐛 Bug

Currently, SequentialLR can only be used with schedulers that inherit from _LRScheduler, or in other words adhere to the way .step() is called, i.e. without any arguments. This is not the case for the built-in ReduceLROnPlateau, as it takes a metric value. Therefore, when calling SequentialLR.step() without arguments, ReduceLROnPlateau will raise an error once its milestone is reached. However, when calling SequentialLR.step(metric), SequentialLR will raise an error due to too many arguments.

Also see https://github.com/PyTorchLightning/pytorch-lightning/issues/10759
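
The mismatch is visible in the signatures themselves; for instance (exact signatures may vary across torch versions):

import inspect
from torch.optim.lr_scheduler import ReduceLROnPlateau, SequentialLR

# ReduceLROnPlateau.step requires a metrics argument, while SequentialLR.step
# (at the time of writing) accepts none, so the two cannot be composed.
print(inspect.signature(ReduceLROnPlateau.step))
print(inspect.signature(SequentialLR.step))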

To Reproduce

from torch.optim.lr_scheduler import ConstantLR, ReduceLROnPlateau, SequentialLR

scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
scheduler2 = ReduceLROnPlateau(self.opt, factor=0.9)
scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
for epoch in range(100):
    train(...)
    validate(...)
    scheduler.step()  # raises an error once epoch >= 2
    scheduler.step(metric)  # raises an error immediately

Expected behavior

SequentialLR should at least work with all built-in PyTorch LR schedulers, and hence allow a metric value to be passed to ReduceLROnPlateau if it is used (or ReduceLROnPlateau needs a rewrite). However, I think SequentialLR should allow arbitrary arguments in step, which would allow arbitrary schedulers (probably another issue in itself).

Either the user takes care of passing the right argument at the right time, or there is a built-in mechanism (if that is even possible).
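
For reference, a manual workaround that avoids SequentialLR altogether is possible today by switching schedulers by hand; a rough sketch, where optimizer, train, and validate stand in for the user's own objects as in the snippet above:

from torch.optim.lr_scheduler import ConstantLR, ReduceLROnPlateau

warmup = ConstantLR(optimizer, factor=0.1, total_iters=2)
plateau = ReduceLROnPlateau(optimizer, factor=0.9)

for epoch in range(100):
    train(...)
    metric = validate(...)
    if epoch < 2:
        warmup.step()           # plain schedulers are stepped without arguments
    else:
        plateau.step(metric)    # ReduceLROnPlateau needs the monitored metric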

cc @vincentqb @jbschlosser @albanD

marcm-ml avatar Nov 29 '21 09:11 marcm-ml

I would like to take a jab at this if no one has already!

Before I open a PR, I'd like to discuss a few proposals.

  1. Convenient edge-casing

We can simply edge-case ReduceLROnPlateau in the SequentialLR.step() function. The step function should take an optional metrics argument and pass it to the scheduler if it is an instance of ReduceLROnPlateau.

# class SequentialLR
    def step(self, metrics=None):
        self.last_epoch += 1
        idx = bisect_right(self._milestones, self.last_epoch)
        scheduler = self._schedulers[idx]
        is_reduce_lr_on_plateau = isinstance(scheduler, ReduceLROnPlateau)
        is_new_scheduler = idx > 0 and self._milestones[idx - 1] == self.last_epoch
        if is_reduce_lr_on_plateau and is_new_scheduler:
            scheduler.step(metrics, 0)
        elif is_reduce_lr_on_plateau:
            scheduler.step(metrics)
        elif is_new_scheduler:
            scheduler.step(0)
        else:
            scheduler.step()
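
With this change, a training loop would look roughly as follows; val_loss is just a placeholder for whatever metric ReduceLROnPlateau monitors:

# Usage sketch for proposal 1: the metric is always passed and is simply
# ignored until the ReduceLROnPlateau milestone is reached.
for epoch in range(100):
    train(...)
    val_loss = validate(...)
    scheduler.step(metrics=val_loss)
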
  2. More general edge-casing

A problem with the previous approach is that users might create custom schedulers that inherit from ReduceLROnPlateau in ways that may not be accounted for. One solution is to make the step function even more general so that the schedulers can do what they want to do with optional args and kwargs. Instead of checking if a scheduler is an instance of ReduceLROnPlateau, we can check if it is an instance of the _LRScheduler base class.

# class SequentialLR
    def step(self, *args, **kwargs):
        self.last_epoch += 1
        idx = bisect_right(self._milestones, self.last_epoch)
        scheduler = self._schedulers[idx]
        is_lr_scheduler = isinstance(scheduler, _LRScheduler)
        is_new_scheduler = idx > 0 and self._milestones[idx - 1] == self.last_epoch
        if is_lr_scheduler and is_new_scheduler:
            scheduler.step(0)
        elif is_lr_scheduler:
            scheduler.step()
        elif is_new_scheduler:
            scheduler.step(*args, **kwargs, epoch=0)
        else:
            scheduler.step(*args, **kwargs)
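
To illustrate the arbitrary-scheduler point, here is a purely hypothetical scheduler that is not an _LRScheduler subclass and consumes its own keyword argument; with the step above, SequentialLR.step(accuracy=val_acc) would reach it untouched (leaving aside whatever else SequentialLR requires of its members):

# Hypothetical example only: reduces the LR whenever the monitored accuracy
# stops improving; proposal 2 would forward accuracy=... straight to it.
class PlateauOnAccuracy:
    def __init__(self, optimizer, factor=0.5):
        self.optimizer = optimizer
        self.factor = factor
        self.best = float('-inf')

    def step(self, accuracy=None, epoch=None):
        if accuracy is None:
            return
        if accuracy > self.best:
            self.best = accuracy
        else:
            for group in self.optimizer.param_groups:
                group['lr'] *= self.factor
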
  3. Inheritance and modifying _LRScheduler

Given issues like https://github.com/pytorch/pytorch/issues/67760, https://github.com/pytorch/pytorch/issues/68332, and https://github.com/pytorch/pytorch/issues/68979, perhaps we should rewrite ReduceLROnPlateau to inherit from _LRScheduler and modify parts of the base class API. This direction would likely be the most time-consuming, but arguably the most beneficial. I don't have code snippets to show right off the bat, as this would require more thought.

What do you think? @jbschlosser @albanD

jaketae avatar Jan 07 '22 14:01 jaketae

Hi! Thanks for writing down a proposal. I'll take a look with @jbschlosser and come back to you by the end of the week!

albanD avatar Jan 11 '22 19:01 albanD

An alternative is to make the step function of _LRScheduler absorb and ignore all arguments (with *args, **kwargs in the signature), which allows the arguments of SequentialLR.step to be propagated without worry. The children of _LRScheduler would be responsible for using these arguments or not. There are several ways to do that.

For instance, we could add a method (say get_get_lr) to _LRScheduler that would absorb all arguments and return self.get_lr(). step would then call this method instead of get_lr.

def get_get_lr(self, *args, **kwargs) -> List[float]:
    return self.get_lr()
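
For completeness, a minimal sketch of how the base class's step could then look, stripped of everything except the forwarding (the real _LRScheduler.step does more bookkeeping):

# Simplified _LRScheduler.step under this proposal: forward everything to
# get_get_lr and apply the learning rates it returns.
def step(self, *args, **kwargs):
    self.last_epoch += 1
    for param_group, lr in zip(self.optimizer.param_groups, self.get_get_lr(*args, **kwargs)):
        param_group['lr'] = lr
    self._last_lr = [group['lr'] for group in self.optimizer.param_groups]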

A scheduler that takes arguments would simply need to override get_get_lr. For instance, this would allow the ReduceLROnPlateau class to be rewritten as

from typing import List, Union

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


def lrs(optimizer: Optimizer) -> List[float]:
    return [group['lr'] for group in optimizer.param_groups]

class ReduceLROnPlateau(_LRScheduler):
    r"""Reduce learning rate when a metric has stopped improving"""

    def __init__(
        self,
        optimizer: Optimizer,
        gamma: float = 0.5,  # <= 1
        patience: int = 7,
        cooldown: int = 1,
        threshold: float = 1e-2,
        mode: str = 'minimize',  # or 'maximize'
        min_lr: Union[float, List[float]] = 1e-6,
        last_epoch: int = -1,
        verbose: bool = False,
    ):
        self.gamma = gamma
        self.patience = patience
        self.cooldown = cooldown
        self.threshold = threshold
        self.mode = mode

        if type(min_lr) is float:
            min_lr = [min_lr] * len(optimizer.param_groups)
        self.min_lrs = min_lr

        self.best = float('-inf') if self.mode == 'maximize' else float('inf')
        self.last_best = last_epoch
        self.last_reduce = last_epoch

        super().__init__(optimizer, last_epoch, verbose)

    def get_get_lr(self, last: float, *args, **kwargs) -> List[float]:
        return self.get_lr(last)

    def get_lr(self, last: float) -> List[float]:
        if self.mode == 'maximize':
            accept = last >= self.best * (1 + self.threshold)
        else:  # mode == 'minimize'
            accept = last <= self.best * (1 - self.threshold)

        if accept:
            self.best = last
            self.last_best = self.last_epoch

            return lrs(self.optimizer)

        if self.last_epoch - max(self.last_best, self.last_reduce + self.cooldown) <= self.patience:
            return lrs(self.optimizer)

        self.last_reduce = self.last_epoch

        return [
            max(lr * self.gamma, min_lr)
            for lr, min_lr in zip(lrs(self.optimizer), self.min_lrs)
        ]
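
As a usage sketch, assuming SequentialLR.step also forwards its arguments as discussed earlier (optimizer, train, and validate are placeholders):

# Sketch: the rewritten ReduceLROnPlateau composed with a warm-up scheduler.
warmup = ConstantLR(optimizer, factor=0.1, total_iters=2)
plateau = ReduceLROnPlateau(optimizer, gamma=0.5, patience=7)
scheduler = SequentialLR(optimizer, schedulers=[warmup, plateau], milestones=[2])

for epoch in range(100):
    train(...)
    val_loss = validate(...)
    scheduler.step(val_loss)  # forwarded to the active scheduler's get_get_lr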

What do you think? @albanD, @jaketae

francois-rozet avatar Jan 24 '22 08:01 francois-rozet

Hey @francois-rozet, thanks for the reply and suggestion! It definitely makes sense. I'm personally in support of any solution that prioritizes implementation simplicity and BC.

Gently pinging @albanD to see if there are any updates from the internal PyTorch team on this!

jaketae avatar Jan 27 '22 23:01 jaketae

Any updates on this?

imarquart avatar Apr 18 '23 08:04 imarquart

No update I'm afraid. The Core team is pretty small and we don't have anyone working on LRScheduler.

My general stance on LR schedulers is that they would require a much bigger (most likely BC-breaking) cleanup to get them to a point where they are reliable. I would be happy to review a fix for this if it is small and backward compatible, but I'm not sure a complex change is worth it given the general state.

albanD avatar Apr 18 '23 13:04 albanD

What about the schedulers is currently not "reliable"?

One problem I encountered, which led me to this existing issue, is that ReduceLROnPlateau cannot be used inside a ChainedScheduler, as it needs the metrics parameter to function.
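
For reference, a minimal sketch of the ChainedScheduler case (the exact failure point may differ between torch versions):

from torch.optim.lr_scheduler import ChainedScheduler, ConstantLR, ReduceLROnPlateau

warmup = ConstantLR(optimizer, factor=0.1, total_iters=2)
plateau = ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
scheduler = ChainedScheduler([warmup, plateau])

scheduler.step()  # takes no arguments, so the metric never reaches ReduceLROnPlateau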

AngledLuffa avatar May 17 '23 14:05 AngledLuffa

Just to ping about this without opening a new issue - is there any possibility of a PR which allows a scheduler such as ReduceLROnPlateau to be used in a SequentialLR or ChainedScheduler?

AngledLuffa avatar Mar 03 '24 23:03 AngledLuffa

Has anyone found a workaround for this? The solutions above did not help me; I get an error that metrics is not passed and is always None, but ReduceLROnPlateau works fine when used on its own.

abbas695 avatar May 18 '24 19:05 abbas695

While it is not very elegant, a quick solution is to override the step function of the SequentialLR class so that it passes metrics to the ReduceLROnPlateau (ROP) scheduler, like this:

from bisect import bisect_right
from typing import Optional

from torch.optim.lr_scheduler import SequentialLR


class SequentialLR_ROP(SequentialLR):
    def __init__(self, optimizer, schedulers, milestones):
        super().__init__(optimizer, schedulers, milestones)

    def step(self, metrics: Optional[float] = None, epoch: Optional[int] = None) -> None:
        self.last_epoch += 1
        idx = bisect_right(self._milestones, self.last_epoch - 1)  # -1 to ensure the previous scheduler stops when it should
        scheduler = self._schedulers[idx]
        if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
            if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
                scheduler.step(metrics)
            else:
                scheduler.step(0)
        else:
            if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
                scheduler.step(metrics)
            else:
                scheduler.step()

        self._last_lr = scheduler.get_last_lr()

And then use SequentialLR_ROP in the same way as regular SequentialLR. I tested it with schedulers=[LinearLR, ROP] and it seems to be working fine.
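
For example (constructor arguments and val_loss are placeholders):

from torch.optim.lr_scheduler import LinearLR, ReduceLROnPlateau

warmup = LinearLR(optimizer, start_factor=0.1, total_iters=5)
plateau = ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
scheduler = SequentialLR_ROP(optimizer, schedulers=[warmup, plateau], milestones=[5])

for epoch in range(100):
    train(...)
    val_loss = validate(...)
    scheduler.step(metrics=val_loss)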

MatZar01 avatar Jul 02 '24 12:07 MatZar01

Unfortunately, the quick solution by @MatZar01 does not work anymore due to #121633.

annalena-k avatar Oct 15 '24 17:10 annalena-k