| import os |
| import time |
| from dataclasses import dataclass |
|
|
| from configs.mode import FaceSwapMode |
| from configs.singleton import Singleton |
|
|
|
|
| @Singleton |
| @dataclass |
| class TrainConfig: |
| mode = FaceSwapMode.MANY_TO_MANY |
| source_name: str = "" |
|
|
| dataset_index: str = "/data/dataset/faceswap/full.pkl" |
| dataset_root: str = "/data/dataset/faceswap" |
|
|
| batch_size: int = 8 |
| num_threads: int = 8 |
| same_rate: float = 0.5 |
| lr: float = 5e-5 |
| grad_clip: float = 1000.0 |
|
|
| use_ddp: bool = True |
|
|
| mouth_mask: bool = True |
| eye_hm_loss: bool = False |
| mouth_hm_loss: bool = False |
|
|
| load_checkpoint = None |
|
|
| identity_extractor_config = { |
| "f_3d_checkpoint_path": "/checkpoints/Deep3DFaceRecon/epoch_20_new.pth", |
| "f_id_checkpoint_path": "/checkpoints/arcface/ms1mv3_arcface_r100_fp16_backbone.pth", |
| "bfm_folder": "/checkpoints/useful_ckpt/BFM", |
| "hrnet_path": "/checkpoints/useful_ckpt/face_98lmks/HR18-WFLW.pth", |
| } |
|
|
| visualize_interval: int = 100 |
| plot_interval: int = 100 |
| max_iters: int = 1000000 |
| checkpoint_interval: int = 40000 |
|
|
| exp_name: str = "exp_base" |
| log_basedir: str = "/data/logs/hififace/" |
| checkpoint_basedir = "/data/checkpoints/hififace" |
|
|
| def __post_init__(self): |
| time_stamp = int(time.time() * 1000) |
| self.log_dir = os.path.join(self.log_basedir, f"{self.exp_name}_{time_stamp}") |
| self.checkpoint_dir = os.path.join(self.checkpoint_basedir, f"{self.exp_name}_{time_stamp}") |
|
|
|
|
| if __name__ == "__main__": |
| tc = TrainConfig() |
| print(tc.log_dir) |
|
|