miqa
xiaoqi-wang commited on
Commit
d4e19cd
·
verified ·
1 Parent(s): a1427df

Upload utils/lr_scheduler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. utils/lr_scheduler.py +113 -0
utils/lr_scheduler.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from timm.scheduler.cosine_lr import CosineLRScheduler
3
+ from timm.scheduler.step_lr import StepLRScheduler
4
+ from timm.scheduler.scheduler import Scheduler
5
+
6
+
7
+ def build_scheduler(config, optimizer, n_iter_per_epoch):
8
+ num_steps = int(config.epochs * n_iter_per_epoch)
9
+ warmup_steps = int(config.warmup_epochs * n_iter_per_epoch)
10
+ decay_steps = int(config.decay_epochs * n_iter_per_epoch)
11
+
12
+ lr_scheduler = None
13
+ if config.lr_scheduler == "cosine":
14
+ lr_scheduler = CosineLRScheduler(
15
+ optimizer,
16
+ t_initial=num_steps,
17
+ lr_min=config.min_lr,
18
+ warmup_lr_init=config.warmup_lr,
19
+ warmup_t=warmup_steps,
20
+ cycle_limit=1,
21
+ t_in_epochs=False,
22
+ )
23
+ elif config.lr_scheduler == "linear":
24
+ lr_scheduler = LinearLRScheduler(
25
+ optimizer,
26
+ t_initial=num_steps,
27
+ lr_min_rate=0.01,
28
+ warmup_lr_init=config.warmup_lr,
29
+ warmup_t=warmup_steps,
30
+ t_in_epochs=False,
31
+ )
32
+ elif config.lr_scheduler == "step":
33
+ lr_scheduler = StepLRScheduler(
34
+ optimizer,
35
+ decay_t=decay_steps,
36
+ decay_rate=config.decay_rate,
37
+ warmup_lr_init=config.warmup_lr,
38
+ warmup_t=warmup_steps,
39
+ t_in_epochs=False,
40
+ )
41
+
42
+ return lr_scheduler
43
+
44
+ # lr_scheduler.step_update()
45
+ # def step_update(self, num_updates: int, metric: float = None):
46
+ # self.metric = metric
47
+ # values = self._get_values(num_updates, on_epoch=False)
48
+ # if values is not None:
49
+ # values = self._add_noise(values, num_updates)
50
+ # self.update_groups(values)
51
+
52
+
53
+ class LinearLRScheduler(Scheduler):
54
+ def __init__(
55
+ self,
56
+ optimizer: torch.optim.Optimizer,
57
+ t_initial: int,
58
+ lr_min_rate: float,
59
+ warmup_t=0,
60
+ warmup_lr_init=0.0,
61
+ t_in_epochs=True,
62
+ noise_range_t=None,
63
+ noise_pct=0.67,
64
+ noise_std=1.0,
65
+ noise_seed=42,
66
+ initialize=True,
67
+ ) -> None:
68
+ super().__init__(
69
+ optimizer,
70
+ param_group_field="lr",
71
+ noise_range_t=noise_range_t,
72
+ noise_pct=noise_pct,
73
+ noise_std=noise_std,
74
+ noise_seed=noise_seed,
75
+ initialize=initialize,
76
+ )
77
+
78
+ self.t_initial = t_initial
79
+ self.lr_min_rate = lr_min_rate
80
+ self.warmup_t = warmup_t
81
+ self.warmup_lr_init = warmup_lr_init
82
+ self.t_in_epochs = t_in_epochs
83
+ if self.warmup_t:
84
+ self.warmup_steps = [
85
+ (v - warmup_lr_init) / self.warmup_t for v in self.base_values
86
+ ]
87
+ super().update_groups(self.warmup_lr_init)
88
+ else:
89
+ self.warmup_steps = [1 for _ in self.base_values]
90
+
91
+ def _get_lr(self, t):
92
+ if t < self.warmup_t:
93
+ lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
94
+ else:
95
+ t = t - self.warmup_t
96
+ total_t = self.t_initial - self.warmup_t
97
+ lrs = [
98
+ v - ((v - v * self.lr_min_rate) * (t / total_t))
99
+ for v in self.base_values
100
+ ]
101
+ return lrs
102
+
103
+ def get_epoch_values(self, epoch: int):
104
+ if self.t_in_epochs:
105
+ return self._get_lr(epoch)
106
+ else:
107
+ return None
108
+
109
+ def get_update_values(self, num_updates: int):
110
+ if not self.t_in_epochs:
111
+ return self._get_lr(num_updates)
112
+ else:
113
+ return None