| import torch |
| import torch.nn as nn |
| import copy |
| from functools import partial |
| from .dasheng import LayerScale, Attention, Mlp |
|
|
|
|
class Decoder_Block(nn.Module):
    """Pre-norm transformer block with optional adaLN conditioning.

    With ``fusion == 'adaln'`` a conditioning vector ``c`` is projected to six
    modulation tensors (scale / gate / shift for both the attention and the
    MLP residual branches).  Any other ``fusion`` value makes the block a
    plain pre-norm transformer block that ignores ``c``.
    """

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            drop=0.,
            attn_drop=0.,
            init_values=None,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            attention_type='Attention',
            fusion='adaln',
    ):
        super().__init__()

        # Attention residual branch.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        # MLP residual branch.
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        self.fusion = fusion
        if fusion == 'adaln':
            # Projects the conditioning vector to 6*dim modulation parameters.
            self.adaln = nn.Linear(dim, 6 * dim, bias=True)

    def forward(self, x, c=None):
        B, T, C = x.shape

        if self.fusion != 'adaln':
            # Unconditional pre-norm transformer block; ``c`` is ignored.
            x = x + self.ls1(self.attn(self.norm1(x)))
            x = x + self.ls2(self.mlp(self.norm2(x)))
            return x

        # (B, 6*dim) -> six (B, 1, dim) tensors that broadcast over T.
        mods = self.adaln(c).reshape(B, 6, -1).chunk(6, dim=1)
        scale_msa, gate_msa, shift_msa, scale_mlp, gate_mlp, shift_mlp = mods

        # NOTE(review): the gate enters as tanh(1 - g); with ``adaln``
        # zero-initialized this starts near tanh(1) ~ 0.76 rather than 0,
        # i.e. not adaLN-zero — presumably intentional so the pretrained
        # residual branches stay active at init.  Confirm against training.
        h = self.norm1(x) * (1 + scale_msa) + shift_msa
        x = x + torch.tanh(1 - gate_msa) * self.ls1(self.attn(h))

        h = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x + torch.tanh(1 - gate_mlp) * self.ls2(self.mlp(h))
        return x
|
|
|
|
class Decoder(nn.Module):
    """Conditional transformer decoder producing per-frame event logits.

    Encoder features ``x`` of shape (B, L, C) are decoded once per class
    query: each of the N class embeddings conditions its own replica of the
    sequence, either through adaLN modulation (``fusion='adaln'``) or by
    prepending the class embedding as an extra token (``fusion='token'``).

    Args:
        embed_dim: transformer width; must match the encoder feature dim.
        depth: number of ``Decoder_Block`` layers.
        num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate:
            standard transformer hyper-parameters forwarded to each block.
        cls_dim: dimensionality of the incoming class embeddings.
        fusion: conditioning strategy, ``'adaln'`` or ``'token'``.
    """

    def __init__(
            self,
            embed_dim: int = 768,
            depth: int = 2,
            num_heads=8,
            mlp_ratio=4.,
            qkv_bias=True,
            drop_rate=0.,
            attn_drop_rate=0.,
            cls_dim: int = 512,
            fusion: str = 'adaln',
            **kwargs
    ):
        super().__init__()

        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        act_layer = nn.GELU
        init_values = None

        block_function = Decoder_Block
        self.blocks = nn.ModuleList([
            block_function(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                init_values=init_values,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                act_layer=act_layer,
                attention_type="Attention",
                fusion=fusion,
            ) for _ in range(depth)
        ])

        self.fusion = fusion
        cls_out = embed_dim

        # Maps the external class embedding (cls_dim) into model width.
        self.cls_embed = nn.Sequential(
            nn.Linear(cls_dim, embed_dim, bias=True),
            nn.SiLU(),
            nn.Linear(embed_dim, cls_out, bias=True),)

        # Per-frame, single-logit sound-event-detection head.
        self.sed_head = nn.Linear(embed_dim, 1, bias=True)
        self.norm = norm_layer(embed_dim)
        self.apply(self.init_weights)

    def init_weights(self, module):
        """Per-module initializer, driven by ``self.apply``."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

        # Zero-init the adaLN projections. ``apply`` visits children before
        # the root, so gating on ``module is self`` runs this exactly once,
        # after the xavier init above has touched the adaln Linears.  (The
        # previous version ran this loop for *every* visited module,
        # redundantly re-zeroing the same weights; the final state is
        # identical.)
        if module is self and self.fusion == 'adaln':
            for block in self.blocks:
                nn.init.constant_(block.adaln.weight, 0)
                nn.init.constant_(block.adaln.bias, 0)

    def forward(self, x, cls):
        """Decode features ``x`` (B, L, C) against class queries ``cls`` (B, N, D).

        Returns:
            Logits of shape (B * N, 1, L).
        """
        B, L, C = x.shape
        _, N, D = cls.shape

        # Replicate the feature sequence once per class query: (B*N, L, C).
        x = x.unsqueeze(1).expand(-1, N, -1, -1)
        x = x.reshape(B * N, L, C)
        cls = cls.reshape(B * N, D)

        cls = self.cls_embed(cls)

        shift = 0
        if self.fusion == 'adaln':
            pass  # conditioning happens inside each block
        elif self.fusion == 'token':
            # Prepend the class embedding as an extra token; stripped below.
            cls = cls.unsqueeze(1)
            x = torch.cat([cls, x], dim=1)
            shift = 1
        else:
            raise NotImplementedError("unknown fusion")

        for block in self.blocks:
            x = block(x, cls)

        x = x[:, shift:]  # drop the class token, if one was prepended

        x = self.norm(x)

        strong = self.sed_head(x)  # (B*N, L, 1)
        return strong.transpose(1, 2)
|
|
|
|
class TSED_Wrapper(nn.Module):
    """Couples a (possibly frozen) Dasheng encoder with the SED decoder.

    The decoder blocks are seeded from the encoder blocks listed in
    ``ft_blocks`` (decoder block i <- encoder block ft_blocks[i]); those
    encoder blocks and the encoder's final norm are then removed from the
    encoder, so the decoder effectively replaces the encoder's tail.

    Args:
        encoder: pretrained Dasheng-style encoder exposing ``blocks``,
            ``norm`` and ``forward_to_spec``.
        decoder: ``Decoder`` whose blocks mirror the encoder block layout.
        ft_blocks: encoder block indices transferred into the decoder.
            Default is a tuple — a list here would be a shared mutable
            default argument.
        frozen_encoder: if True, encoder params get ``requires_grad=False``
            and the encoder forward runs under ``torch.no_grad()``.
    """

    def __init__(
            self,
            encoder,
            decoder,
            ft_blocks=(11, 12),
            frozen_encoder=True
    ):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

        # Seed decoder blocks with the chosen encoder blocks' weights.
        print("Loading Dasheng weights for decoders...")
        for i, blk_idx in enumerate(ft_blocks):
            decoder_block = self.decoder.blocks[i]
            encoder_block = self.encoder.blocks[blk_idx]
            state_dict = copy.deepcopy(encoder_block.state_dict())
            # strict=False: decoder blocks may carry extra params (e.g. adaln)
            # that have no counterpart in the encoder block.
            missing, unexpected = decoder_block.load_state_dict(state_dict, strict=False)
            if missing or unexpected:
                print(f"Block {blk_idx}:")
                if missing:
                    print(f"✅ Expected missing keys: {missing}")
                if unexpected:
                    print(f" Unexpected keys: {unexpected}")

        self.decoder.norm.load_state_dict(copy.deepcopy(self.encoder.norm.state_dict()))

        # Remove the transferred blocks from the encoder — highest index
        # first so earlier indices remain valid while deleting.
        for blk_idx in sorted(ft_blocks, reverse=True):
            del self.encoder.blocks[blk_idx]
        # The decoder's own final norm now plays this role.
        del self.encoder.norm

        self.frozen_encoder = frozen_encoder
        if frozen_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

    def forward_to_spec(self, x):
        """Expose the encoder's waveform -> spectrogram front end."""
        return self.encoder.forward_to_spec(x)

    def forward_encoder(self, x):
        """Run the encoder; skip gradient tracking when it is frozen."""
        if self.frozen_encoder:
            with torch.no_grad():
                x = self.encoder(x)
        else:
            x = self.encoder(x)
        return x

    def forward(self, x, cls):
        """Encode audio ``x``, then decode against class queries ``cls``."""
        x = self.forward_encoder(x)
        pred = self.decoder(x, cls)
        return pred
|
|