| from __future__ import annotations |
|
|
| import logging |
|
|
| import torch |
|
|
| from modules import ( |
| devices, |
| errors, |
| face_restoration, |
| face_restoration_utils, |
| modelloader, |
| shared, |
| ) |
|
|
# Logger named after this module.
logger = logging.getLogger(name=__name__)

# File name handed to modelloader.load_models as ``download_name``.
model_download_name = 'codeformer-v0.1.0.pth'
# Release asset URL handed to modelloader.load_models as ``model_url``.
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'

# Restorer instance assigned by setup_model(); None until setup_model() constructs one.
codeformer: face_restoration.FaceRestoration | None = None
|
|
|
|
class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
    """Face restorer backed by the CodeFormer network.

    This class supplies the CodeFormer-specific pieces: locating and loading
    the weights, the device to use, and the per-face forward call. The rest of
    the restoration pipeline is inherited from ``CommonFaceRestoration``
    (presumably face detection, cropping and pasting back via
    ``restore_with_helper`` — see the base class).
    """

    def name(self) -> str:
        """Return the display name of this restorer."""
        return "CodeFormer"

    def load_net(self) -> torch.nn.Module:
        """Load the CodeFormer network and return its underlying torch module.

        ``modelloader.load_models`` is asked for ``.pth`` files under
        ``self.model_path``, with ``model_url`` / ``model_download_name`` as
        the download source (presumably only used when no local file is
        found — confirm in ``modelloader``). The first path it yields is
        loaded through spandrel onto ``devices.device_codeformer``, requiring
        the ``CodeFormer`` architecture.

        Raises:
            ValueError: if ``load_models`` yields no model paths.
        """
        model_paths = modelloader.load_models(
            model_path=self.model_path,
            model_url=model_url,
            command_path=self.model_path,
            download_name=model_download_name,
            ext_filter=['.pth'],
        )
        # Only the first candidate is used; any further paths are ignored.
        for model_path in model_paths:
            return modelloader.load_spandrel_model(
                model_path,
                device=devices.device_codeformer,
                expected_architecture='CodeFormer',
            ).model
        raise ValueError("No codeformer model found")

    def get_device(self):
        """Return ``devices.device_codeformer``, the device ``load_net`` loads onto."""
        return devices.device_codeformer

    def restore(self, np_image, w: float | None = None):
        """Restore faces in ``np_image`` with CodeFormer.

        Args:
            np_image: the image to process. Presumably a NumPy array (per the
                name); it is passed straight to ``restore_with_helper``.
            w: weight forwarded to the network as its ``w`` argument. When
                None, ``shared.opts.code_former_weight`` is used, falling back
                to 0.5 if that option is not defined.

        Returns:
            Whatever ``restore_with_helper`` returns for this image.
        """
        if w is None:
            w = getattr(shared.opts, "code_former_weight", 0.5)

        def restore_face(cropped_face_t):
            # Callback for restore_with_helper (presumably called once per
            # cropped face). Only the first element of the network output is kept.
            assert self.net is not None
            return self.net(cropped_face_t, w=w, adain=True)[0]

        return self.restore_with_helper(np_image, restore_face)
|
|
|
|
def setup_model(dirname: str) -> None:
    """Build the CodeFormer restorer for ``dirname`` and register it.

    On success the new instance is stored in the module-level ``codeformer``
    and appended to ``shared.face_restorers``. Any exception raised while
    constructing or registering it is reported via ``errors.report`` rather
    than propagated.
    """
    global codeformer
    try:
        instance = FaceRestorerCodeFormer(dirname)
        codeformer = instance
        shared.face_restorers.append(instance)
    except Exception:
        errors.report("Error setting up CodeFormer", exc_info=True)
|
|