Here is an example of a custom learning-rate scheduler I wrote by subclassing _LRScheduler in PyTorch. It holds the learning rate at warmup_lr for the first warmup_steps steps, then follows a cosine curve that oscillates between warmup_lr and lr_max with period cosine_t.

import math

import torch
from torch.optim import Optimizer


class WarmUpAndCosineAnnealingLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        optimizer: Optimizer,
        warmup_steps: int,
        warmup_lr: float,
        lr_max: float,
        cosine_t: int,
        last_epoch: int = -1,
        verbose: bool = False,
    ) -> None:

        assert warmup_steps >= 0, "warmup_steps must be >= 0"
        self.warmup_steps = warmup_steps

        assert warmup_lr > 0, "warmup_lr must be > 0"
        self.warmup_lr = warmup_lr

        assert lr_max > warmup_lr, "lr_max must be > warmup_lr"
        self.lr_max = lr_max

        assert cosine_t > 0, "cosine_t must be > 0"
        self.cosine_t = cosine_t

        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self) -> list[float]:
        # Warmup phase: hold the learning rate constant at warmup_lr.
        # _step_count counts calls to step(), including the one made
        # by _LRScheduler.__init__ itself.
        if self._step_count < self.warmup_steps:
            return [self.warmup_lr for _ in self.base_lrs]

        # Cosine phase: since cos(pi + theta) = -cos(theta), the rate
        # starts at warmup_lr, peaks at lr_max after cosine_t / 2
        # steps, and returns to warmup_lr after a full period.
        x = self._step_count - self.warmup_steps
        amplitude = self.lr_max - self.warmup_lr

        new_lr = self.warmup_lr + amplitude / 2 * (
            1 + math.cos(math.pi + x * 2 * math.pi / self.cosine_t)
        )

        return [new_lr for _ in self.base_lrs]
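
Here is a quick sketch of how this might be driven in a training loop. The model, optimizer, and hyperparameter values below are illustrative, not from the original post:

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

scheduler = WarmUpAndCosineAnnealingLRScheduler(
    optimizer,
    warmup_steps=100,  # hold warmup_lr for the first 100 steps
    warmup_lr=1e-4,
    lr_max=1e-2,
    cosine_t=1000,  # cosine period, in steps
)

for step in range(2000):
    # ... forward pass, loss.backward(), optimizer.step() ...
    scheduler.step()  # writes the new rate into optimizer.param_groups

Note that get_lr() ignores the lr passed to the optimizer; the schedule is fully determined by warmup_lr, lr_max, and cosine_t.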

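To sanity-check the cosine expression on its own, here is a small standalone evaluation at a few points (the hyperparameter values are again made up):

import math

warmup_lr, lr_max, cosine_t = 1e-4, 1e-2, 1000
amplitude = lr_max - warmup_lr

for x in (0, 250, 500, 750, 1000):
    lr = warmup_lr + amplitude / 2 * (
        1 + math.cos(math.pi + x * 2 * math.pi / cosine_t)
    )
    print(x, round(lr, 6))
# x=0    -> 0.0001  (trough: warmup_lr)
# x=500  -> 0.01    (peak: lr_max at half period)
# x=1000 -> 0.0001  (back down to warmup_lr)
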
Categories: pytorch
