Here is an example of a custom learning-rate scheduler that I subclassed in PyTorch: a constant warmup phase followed by cosine annealing.
```python
import math

import torch
from torch.optim import Optimizer


class WarmUpAndCosineAnnealingLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Hold warmup_lr constant for warmup_steps, then follow a cosine curve
    that rises from warmup_lr to lr_max and repeats with period cosine_t."""

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_steps: int,
        warmup_lr: float,
        lr_max: float,
        cosine_t: int,
        last_epoch: int = -1,
        verbose: bool = False,
    ) -> None:
        assert warmup_steps >= 0, "warmup_steps >= 0 required"
        self.warmup_steps = warmup_steps
        assert warmup_lr > 0, "warmup_lr > 0 required"
        self.warmup_lr = warmup_lr
        assert lr_max > warmup_lr, "lr_max > warmup_lr required"
        self.lr_max = lr_max
        assert cosine_t > 0, "cosine_t > 0 required"
        self.cosine_t = cosine_t
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self) -> list:
        if self._step_count < self.warmup_steps:
            # Warmup phase: hold the learning rate constant.
            return [self.warmup_lr for _ in self.base_lrs]
        else:
            # Cosine phase: starts at warmup_lr (cos(pi) = -1), peaks at
            # lr_max half a period later, and repeats every cosine_t steps.
            x = self._step_count - self.warmup_steps
            amplitude = self.lr_max - self.warmup_lr
            new_lr = self.warmup_lr + amplitude / 2 * (
                1 + math.cos(math.pi + x * 2 * math.pi / self.cosine_t)
            )
            return [new_lr for _ in self.base_lrs]
```
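Here is a minimal usage sketch; the model, optimizer, and hyperparameter values below are placeholders, not part of the original code. Since `get_lr` reads `self._step_count`, calling `scheduler.step()` once per batch makes `warmup_steps` and `cosine_t` counts of optimizer steps rather than epochs.

```python
import torch

# Stand-in model and optimizer for illustration only.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

scheduler = WarmUpAndCosineAnnealingLRScheduler(
    optimizer,
    warmup_steps=500,   # hold warmup_lr for the first 500 steps
    warmup_lr=1e-4,
    lr_max=1e-2,
    cosine_t=10_000,    # full cosine period, in steps
)

for _ in range(1_000):
    loss = model(torch.randn(32, 10)).sum()  # dummy forward pass
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()  # advance the schedule once per batch
```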