| import os |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class Encoder(nn.Module):
    """Feature extractor that wraps a backbone loaded through ``torch.hub``.

    The backbone (``tf_efficientnet_b5_ap``) is loaded with
    ``pretrained=False`` from the ``efficientnet_repo`` directory located
    next to this file. Its ``global_pool`` and ``classifier`` attributes are
    replaced with ``nn.Identity`` so they pass their input through unchanged.
    """

    def __init__(self):
        """Load the backbone from the local hub repo and strip its head."""
        super().__init__()

        basemodel_name = 'tf_efficientnet_b5_ap'
        # No trailing newline: the 'Done.' below completes the same line.
        print('Loading base model ({})...'.format(basemodel_name), end='')
        repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo')
        basemodel = torch.hub.load(
            repo_path, basemodel_name, pretrained=False, source='local'
        )
        print('Done.')

        # The layers are replaced rather than deleted, so they still occupy
        # a slot in ``_modules`` and each still contributes one (pass-through)
        # entry to the list returned by ``forward``.
        print('Removing last two layers (global_pool & classifier).')
        basemodel.global_pool = nn.Identity()
        basemodel.classifier = nn.Identity()

        self.original_model = basemodel

    def forward(self, x):
        """Run ``x`` through the backbone, recording every intermediate output.

        Each direct child of the backbone is applied in registration order to
        the most recent feature. The child registered as ``'blocks'`` is not
        called as a whole; its own direct children are applied one at a time
        so that each of their outputs is recorded separately.

        Args:
            x: Input passed to the first child of the backbone.

        Returns:
            list: ``[x, out_1, out_2, ...]`` - the input followed by the output
            of every applied module, in application order.
        """
        features = [x]
        for name, module in self.original_model._modules.items():
            if name == 'blocks':
                # Expand the container so each sub-module yields its own entry.
                for block in module._modules.values():
                    features.append(block(features[-1]))
            else:
                features.append(module(features[-1]))
        return features
|
|
|
|
|
|