| | import os |
| | from functools import reduce |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from .mobilenetv2 import MobileNetV2 |
| |
|
| |
|
class BaseBackbone(nn.Module):
    """Common base for the swappable encoder backbones.

    Stores ``in_channels`` and creates two placeholders that a concrete
    backbone is meant to fill in its own ``__init__``:

    * ``model`` -- the wrapped network (``None`` here);
    * ``enc_channels`` -- one entry per feature map returned by ``forward``
      (empty here).

    ``forward`` and ``load_pretrained_ckpt`` must be overridden.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        # Placeholders only; subclasses replace both.
        self.model, self.enc_channels = None, []

    def forward(self, x):
        """Encode ``x``. Abstract: always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def load_pretrained_ckpt(self):
        """Load pretrained weights. Abstract: always raises ``NotImplementedError``."""
        raise NotImplementedError()
| |
|
| |
|
class MobileNetV2Backbone(BaseBackbone):
    """MobileNetV2 backbone exposing five intermediate feature maps.

    ``forward`` runs ``self.model.features[0]`` through ``features[18]`` in
    order and returns the activations taken after indices 1, 3, 6, 13 and 18.
    """

    def __init__(self, in_channels):
        super(MobileNetV2Backbone, self).__init__(in_channels)

        # NOTE(review): alpha / expansion / num_classes are interpreted by
        # .mobilenetv2; num_classes=None presumably builds the network without
        # a classification head -- confirm there.
        self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None)
        # One entry per tensor returned by forward, in the same order;
        # presumably the channel width of each of those feature maps -- keep in
        # sync with .mobilenetv2 if alpha/expansion are changed.
        self.enc_channels = [16, 24, 32, 96, 1280]

    def forward(self, x):
        """Run the encoder and collect its intermediate feature maps.

        Args:
            x: input passed to ``self.model.features[0]``; presumably an
                (N, in_channels, H, W) tensor -- confirm with the caller.

        Returns:
            list: ``[enc2x, enc4x, enc8x, enc16x, enc32x]``, the outputs of
            ``features[1]``, ``[3]``, ``[6]``, ``[13]`` and ``[18]``
            respectively. The names suggest each map's stride relative to
            the input (2x ... 32x).
        """
        # Every index below is a literal on purpose.
        # NOTE(review): folding these calls into a loop with a variable index
        # may break torch.jit.script, which has historically required literal
        # indices into nn.Sequential / nn.ModuleList -- check before refactoring.

        # features[0:2] -> enc2x
        x = self.model.features[0](x)
        x = self.model.features[1](x)
        enc2x = x

        # features[2:4] -> enc4x
        x = self.model.features[2](x)
        x = self.model.features[3](x)
        enc4x = x

        # features[4:7] -> enc8x
        x = self.model.features[4](x)
        x = self.model.features[5](x)
        x = self.model.features[6](x)
        enc8x = x

        # features[7:14] -> enc16x
        x = self.model.features[7](x)
        x = self.model.features[8](x)
        x = self.model.features[9](x)
        x = self.model.features[10](x)
        x = self.model.features[11](x)
        x = self.model.features[12](x)
        x = self.model.features[13](x)
        enc16x = x

        # features[14:19] -> enc32x
        x = self.model.features[14](x)
        x = self.model.features[15](x)
        x = self.model.features[16](x)
        x = self.model.features[17](x)
        x = self.model.features[18](x)
        enc32x = x
        return [enc2x, enc4x, enc8x, enc16x, enc32x]

    def load_pretrained_ckpt(self, ckpt_path='./pretrained/mobilenetv2_human_seg.ckpt'):
        """Load pretrained MobileNetV2 weights into ``self.model``.

        Args:
            ckpt_path: checkpoint file whose loaded content is passed directly
                to ``self.model.load_state_dict``. The default is relative to
                the current working directory.

        Raises:
            SystemExit: with a non-zero exit status and a message on stderr
                when ``ckpt_path`` is not an existing regular file.
        """
        if not os.path.isfile(ckpt_path):
            # SystemExit(msg) prints msg to stderr and exits with status 1.
            # The previous bare ``exit()`` reported status 0 (success) and
            # depended on the ``site``-injected ``exit`` helper.
            raise SystemExit(
                'cannot find the pretrained mobilenetv2 backbone: {}'.format(ckpt_path))

        # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
        # CPU-only machine; load_state_dict then copies the values into the
        # model's existing parameters, wherever those live.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        self.model.load_state_dict(ckpt)
| |
|