| |
| """ |
| @author: Jianjie Luo |
| @contact: jianjieluo.sysu@gmail.com |
| """ |
| import torch |
| from uniperceiver.config import configurable |
| from .build import SOLVER_REGISTRY |
|
|
@SOLVER_REGISTRY.register()
class AdamW(torch.optim.AdamW):
    """Config-constructible wrapper around ``torch.optim.AdamW``.

    The class is registered in ``SOLVER_REGISTRY`` (presumably so the solver
    can be looked up by name from the config — confirm against the registry's
    callers). Optimization behaviour is exactly that of ``torch.optim.AdamW``.
    The only addition is ``from_config``, which maps ``cfg.SOLVER.*`` entries
    onto the constructor's keyword arguments.
    """

    @configurable
    def __init__(
        self,
        *,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.01,
        amsgrad=False,
    ):
        """Build the optimizer; every argument is forwarded to ``torch.optim.AdamW``.

        Args:
            params: iterable of parameters (or parameter-group dicts) to optimize.
            lr: learning rate.
            betas: coefficients for the running averages of the gradient and
                of its square.
            eps: term added to the denominator for numerical stability.
            weight_decay: decoupled weight-decay coefficient.
            amsgrad: whether to use the AMSGrad variant.
        """
        # Forward by keyword, not position, so this wrapper does not silently
        # depend on the parent constructor's positional parameter order.
        super().__init__(
            params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )

    @classmethod
    def from_config(cls, cfg, params):
        """Translate a config node into keyword arguments for ``__init__``.

        Args:
            cfg: config object; reads ``SOLVER.BASE_LR``, ``SOLVER.BETAS``,
                ``SOLVER.EPS``, ``SOLVER.WEIGHT_DECAY`` and ``SOLVER.AMSGRAD``.
            params: parameters / parameter groups to optimize.

        Returns:
            dict of keyword arguments accepted by ``__init__``.
        """
        return {
            "params": params,
            "lr": cfg.SOLVER.BASE_LR,
            # Config files may deliver BETAS as a list; torch documents it as
            # a (beta1, beta2) tuple, so normalize it here.
            "betas": tuple(cfg.SOLVER.BETAS),
            "eps": cfg.SOLVER.EPS,
            "weight_decay": cfg.SOLVER.WEIGHT_DECAY,
            "amsgrad": cfg.SOLVER.AMSGRAD,
        }
|
|