| import json |
| import torch |
| from torch import nn, Tensor |
| from loguru import logger |
| from pathlib import Path |
|
|
| from torchvision.transforms import ToTensor |
| from torchvision.transforms.v2 import CenterCrop, Compose, Normalize |
|
|
|
|
| import vits |
|
|
| def _clean_moco_state_dict(state_dict: dict[str, Tensor], linear_keyword: str) -> dict[str, Tensor]: |
| """ |
| Filters and renames keys from a MoCo state_dict. |
| |
| It selects keys from the 'base_encoder', removes the given linear layer keyword, |
| and strips the 'module.base_encoder.' prefix. |
| """ |
| for key in list(state_dict.keys()): |
| |
| if key.startswith('module.base_encoder') and not key.startswith(f'module.base_encoder.{linear_keyword}'): |
| |
| new_key = key[len("module.base_encoder."):] |
| state_dict[new_key] = state_dict[key] |
|
|
| |
| del state_dict[key] |
|
|
| return state_dict |
|
|
def load_moco_encoder(
    model: nn.Module,
    weight_path: Path,
    linear_keyword: str,
) -> nn.Module:
    """
    Loads pre-trained MoCo weights into a given model instance (ResNet, ViT, etc.).

    This function handles loading the checkpoint, cleaning the state dictionary keys,
    and loading the weights into the model's backbone. It finishes by replacing
    the model's linear head with an Identity layer to turn it into a feature extractor.

    Args:
        model: An instantiated PyTorch model (e.g., from timm or a custom module).
        weight_path: Path to the .pth or .pt MoCo checkpoint file.
        linear_keyword: The name of the final linear layer to exclude (e.g., 'fc' or 'head').

    Returns:
        The same model, with pre-trained backbone weights and the head replaced
        by nn.Identity(), ready for feature extraction.

    Raises:
        FileNotFoundError: If *weight_path* does not exist.
    """
    # Raise instead of assert: asserts are stripped under `python -O`,
    # which would turn a missing checkpoint into a confusing torch.load error.
    if not weight_path.exists():
        raise FileNotFoundError(f"Checkpoint not found at '{weight_path}'")
    logger.info(f"=> Loading MoCo checkpoint from '{weight_path}'")

    # weights_only=True refuses arbitrary pickled objects in the checkpoint.
    checkpoint = torch.load(weight_path, map_location="cpu", weights_only=True)
    state_dict = checkpoint["state_dict"]
    cleaned_state_dict = _clean_moco_state_dict(state_dict, linear_keyword)

    # strict=False: the head's keys were deliberately left un-renamed by the
    # cleaner, so some keys are expected to be missing/unexpected here.
    msg = model.load_state_dict(cleaned_state_dict, strict=False)
    logger.info(msg)
    logger.info("=> Successfully loaded pre-trained model backbone.")

    if hasattr(model, linear_keyword):
        setattr(model, linear_keyword, nn.Identity())
        logger.info(f"=> Model's '{linear_keyword}' layer replaced with nn.Identity for feature extraction.")
    else:
        # Surface the mismatch instead of silently returning a model whose
        # head is still intact (the docstring promises a feature extractor).
        logger.warning(f"=> Model has no '{linear_keyword}' attribute; head was NOT replaced.")

    return model
|
|
def get_vit_feature_extractor(weight_path: Path, model_name: str = "vits8", img_size: int = 40) -> nn.Module:
    """Build a ViT backbone and turn it into a MoCo feature extractor.

    Looks up the constructor named *model_name* in the ``vits`` module,
    instantiates it headless (``num_classes=0``), then delegates checkpoint
    loading and head replacement to :func:`load_moco_encoder`.
    """
    build_backbone = vits.__dict__[model_name]
    backbone = build_backbone(img_size=img_size, num_classes=0)
    return load_moco_encoder(
        model=backbone,
        weight_path=weight_path,
        linear_keyword="head",
    )
|
|
|
|
def prepare_transform(
    stats_path: Path,
    size: int = 40,
) -> Compose:
    """
    Build the preprocessing pipeline from stored normalization statistics.

    Args:
        stats_path: Path to a JSON file containing 'mean' and 'std' lists
            (anything `open` accepts — str or os.PathLike — also works).
        size: Side length of the center crop applied after normalization.

    Returns:
        A Compose of ToTensor -> Normalize(mean, std) -> CenterCrop(size).

    Raises:
        FileNotFoundError: If *stats_path* does not exist.
        KeyError: If the JSON file lacks a 'mean' or 'std' entry.
    """
    # Explicit encoding keeps the read deterministic across platforms.
    with open(stats_path, "r", encoding="utf-8") as f:
        norm_dict = json.load(f)
    mean = norm_dict["mean"]
    std = norm_dict["std"]

    # Order preserved from the original pipeline: normalize first, then crop.
    return Compose(
        [
            ToTensor(),
            Normalize(mean=mean, std=std),
            CenterCrop(size=size),
        ]
    )
|
|