import pdb  # NOTE(review): debug-only import, appears unused — consider removing
import scipy
import numpy as np

# Compatibility shim: `scipy.inf` was removed in recent SciPy releases, but a
# downstream dependency imported below (presumably msaf) still references
# it — TODO confirm which import needs this and drop once upstream is fixed.
scipy.inf = np.inf

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from dataset.custom_types import MsaInfo
from msaf.eval import compute_results
from postprocessing.functional import postprocess_functional_structure
from x_transformers import Encoder
import bisect
| |
|
| |
|
class Head(nn.Module):
    """A small MLP prediction head applied independently to every time step.

    Builds ``Linear -> act -> Linear -> act -> ... -> Linear`` where the
    activation is inserted between layers but never after the final one.
    """

    # Supported activation factories, keyed by lowercase name.
    _ACTIVATIONS = {"relu": nn.ReLU, "silu": nn.SiLU, "gelu": nn.GELU}

    def __init__(self, input_dim, output_dim, hidden_dims=None, activation="silu"):
        super().__init__()
        make_act = self._ACTIVATIONS.get(activation.lower())
        if not make_act:
            raise ValueError(f"Unsupported activation: {activation}")

        widths = [input_dim, *(hidden_dims or []), output_dim]
        modules = []
        last_pair = len(widths) - 2
        for idx, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
            modules.append(nn.Linear(d_in, d_out))
            # No activation after the output layer — callers expect raw logits.
            if idx < last_pair:
                modules.append(make_act())
        self.net = nn.Sequential(*modules)

    def reset_parameters(self, confidence):
        """Set the output bias to the prior logit -log((1 - p) / p).

        This is the standard focal-loss initialization so the head starts out
        predicting the given positive-class ``confidence``.
        """
        prior_logit = -torch.log(torch.tensor((1 - confidence) / confidence))
        self.net[-1].bias.data.fill_(prior_logit.item())

    def forward(self, x):
        """Map (B, T, C_in) to (B, T, C_out) by flattening the time axis."""
        b, t, c = x.shape
        flat = self.net(x.reshape(b * t, c))
        return flat.reshape(b, t, -1)
| |
|
| |
|
class WrapedTransformerEncoder(nn.Module):
    """Projects features to the transformer width and runs an x-transformers Encoder.

    The (misspelled) class name is kept for backward compatibility with
    existing callers and checkpoints.
    """

    def __init__(
        self, input_dim, transformer_input_dim, num_layers=1, nhead=8, dropout=0.1
    ):
        super().__init__()
        self.input_dim = input_dim
        self.transformer_input_dim = transformer_input_dim

        # Only insert a projection MLP when the widths actually differ.
        if input_dim != transformer_input_dim:
            self.input_proj = nn.Sequential(
                nn.Linear(input_dim, transformer_input_dim),
                nn.LayerNorm(transformer_input_dim),
                nn.GELU(),
                nn.Dropout(dropout * 0.5),
                nn.Linear(transformer_input_dim, transformer_input_dim),
            )
        else:
            self.input_proj = nn.Identity()

        self.transformer = Encoder(
            dim=transformer_input_dim,
            depth=num_layers,
            heads=nhead,
            layer_dropout=dropout,
            attn_dropout=dropout,
            ff_dropout=dropout,
            attn_flash=True,
            rotary_pos_emb=True,
        )

    def forward(self, x, src_key_padding_mask=None):
        """
        Args:
            x: (B, T, input_dim) features.
            src_key_padding_mask: optional (B, T) boolean mask where True marks
                masked (padded) positions, following the torch.nn convention.

        x-transformers uses the opposite convention (False = masked), so the
        mask is inverted before being passed to the encoder.
        """
        x = self.input_proj(x)
        mask = None
        if src_key_padding_mask is not None:
            # torch.as_tensor avoids the copy and UserWarning that
            # torch.tensor() emits when the mask is already a tensor, while
            # still accepting lists / ndarrays.
            mask = ~torch.as_tensor(
                src_key_padding_mask, dtype=torch.bool, device=x.device
            )
        return self.transformer(x, mask=mask)
| |
|
| |
|
def prefix_dict(d, prefix: str):
    """Return a copy of *d* with *prefix* prepended to every key.

    If *prefix* is falsy (empty string or None), *d* is returned unchanged.

    Bug fix: the condition was inverted (`if prefix: return d`), which made
    the function a no-op exactly when a prefix was supplied — so metric/loss
    dicts were never actually prefixed by the `if prefix:` call sites.
    """
    if not prefix:
        return d
    return {prefix + key: value for key, value in d.items()}
| |
|
| |
|
class TimeDownsample(nn.Module):
    """Temporal downsampling with a depthwise-separable 1D convolution.

    The main path is depthwise conv -> pointwise (1x1) conv; a pooled
    shortcut (optionally 1x1-projected to match widths) is added before
    LayerNorm -> GELU -> Dropout. Input and output layout is (B, T, C).
    """

    def __init__(
        self, dim_in, dim_out=None, kernel_size=5, stride=5, padding=0, dropout=0.1
    ):
        super().__init__()
        self.dim_out = dim_out or dim_in
        assert self.dim_out % 2 == 0

        conv_geometry = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        # Depthwise stage: one filter per input channel (groups == channels).
        self.depthwise_conv = nn.Conv1d(
            in_channels=dim_in,
            out_channels=dim_in,
            groups=dim_in,
            bias=False,
            **conv_geometry,
        )
        # Pointwise stage: 1x1 conv mixing channels into dim_out.
        self.pointwise_conv = nn.Conv1d(
            in_channels=dim_in,
            out_channels=self.dim_out,
            kernel_size=1,
            bias=False,
        )
        # Shortcut path pools with the same temporal geometry as the conv.
        self.pool = nn.AvgPool1d(kernel_size, stride, padding=padding)
        self.norm1 = nn.LayerNorm(self.dim_out)
        self.act1 = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)

        # 1x1 projection so the shortcut matches dim_out when widths differ.
        self.residual_conv = (
            nn.Conv1d(dim_in, self.dim_out, kernel_size=1, bias=False)
            if dim_in != self.dim_out
            else None
        )

    def forward(self, x):
        # (B, T, C) -> (B, C, T) for the convolutional stack.
        main = self.pointwise_conv(self.depthwise_conv(x.transpose(1, 2)))

        shortcut = self.pool(x.transpose(1, 2))
        if self.residual_conv is not None:
            shortcut = self.residual_conv(shortcut)

        fused = (main + shortcut).transpose(1, 2)
        return self.dropout1(self.act1(self.norm1(fused)))
| |
|
| |
|
class AddFuse(nn.Module):
    """Fuses a conditioning signal into features by element-wise addition.

    Kept as a module (rather than a bare `+`) so the fusion op is a named,
    swappable component of the model graph.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, cond):
        return torch.add(x, cond)
| |
|
| |
|
class TVLoss1D(nn.Module):
    """Total-variation smoothness penalty for 1-D boundary score curves.

    Penalizes frame-to-frame changes |p_t - p_{t-1}|^beta. When ground-truth
    labels are supplied, transitions adjacent to labelled boundary frames are
    down-weighted so the model stays free to produce sharp peaks there.
    """

    def __init__(
        self, beta=1.0, lambda_tv=0.4, boundary_threshold=0.01, reduction_weight=0.1
    ):
        """
        Args:
            beta: exponent applied to the absolute differences (typically 0.5-1.0).
            lambda_tv: global scale applied to the returned loss.
            boundary_threshold: label value above which a frame counts as boundary.
            reduction_weight: multiplier (< 1) applied to the penalty near boundaries.
        """
        super().__init__()
        self.beta = beta
        self.lambda_tv = lambda_tv
        self.boundary_threshold = boundary_threshold
        self.reduction_weight = reduction_weight

    def forward(self, pred, target=None):
        """Compute the (optionally label-weighted) TV loss.

        Args:
            pred: (B, T) or (B, T, 1) predicted boundary scores.
            target: optional (B, T) or (B, T, 1) ground-truth boundary labels.

        Returns:
            Scalar loss tensor.
        """
        if pred.dim() == 3:
            pred = pred.squeeze(-1)
        if target is not None and target.dim() == 3:
            target = target.squeeze(-1)

        # |delta|^beta per adjacent frame pair; epsilon keeps pow stable at 0.
        step = pred[:, 1:] - pred[:, :-1]
        penalty = (step.abs() + 1e-8) ** self.beta

        if target is None:
            return self.lambda_tv * penalty.mean()

        # A transition touches a boundary when either endpoint exceeds the threshold.
        touches_boundary = (target[:, :-1] > self.boundary_threshold) | (
            target[:, 1:] > self.boundary_threshold
        )
        weights = torch.where(
            touches_boundary,
            self.reduction_weight * torch.ones_like(penalty),
            torch.ones_like(penalty),
        )
        return self.lambda_tv * (penalty * weights).mean()
| |
|
| |
|
class SoftmaxFocalLoss(nn.Module):
    """Focal loss on top of a softmax, for single-label multi-class targets.

    Accepts either hard integer labels or soft probability distributions and
    returns the unreduced per-position loss.
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: [B, T, C] raw logits.
            targets: [B, T, C] soft distributions, or [B, T] long class indices.

        Returns:
            [B, T] per-position focal loss (no reduction applied).
        """
        log_probs = F.log_softmax(pred, dim=-1)
        probs = log_probs.exp()

        # Normalize hard labels to a one-hot distribution.
        if targets.dtype == torch.long:
            targets_onehot = F.one_hot(targets, num_classes=pred.size(-1)).float()
        else:
            targets_onehot = targets

        # Probability mass assigned to the true class, clamped away from 0/1.
        p_t = (probs * targets_onehot).sum(dim=-1).clamp(1e-8, 1.0 - 1e-8)

        # Alpha weighting of the true-class term; disabled when alpha <= 0.
        if self.alpha > 0:
            alpha_t = self.alpha * targets_onehot + (1 - self.alpha) * (
                1 - targets_onehot
            )
            alpha_weight = (alpha_t * targets_onehot).sum(dim=-1)
        else:
            alpha_weight = 1.0

        # Focal modulation of the (soft) cross-entropy term.
        ce_loss = -(log_probs * targets_onehot).sum(dim=-1)
        return alpha_weight * (1 - p_t) ** self.gamma * ce_loss
| |
|
| |
|
class Model(nn.Module):
    """Music-structure-analysis model.

    Pipeline: raw embeddings -> linear width reduction -> LayerNorm ->
    temporal downsampling conv -> add per-dataset embedding -> transformer
    encoder -> two heads (per-frame boundary score, per-frame section-function
    logits). Also bundles loss computation and msaf-based evaluation.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.input_norm = nn.LayerNorm(config.input_dim)
        # Projects the raw (mixed-window) feature width down to input_dim.
        self.mixed_win_downsample = nn.Linear(config.input_dim_raw, config.input_dim)
        # Learned per-dataset conditioning embedding, added to the sequence.
        self.dataset_class_prefix = nn.Embedding(
            num_embeddings=config.num_dataset_classes,
            embedding_dim=config.transformer_encoder_input_dim,
        )
        # Temporal downsampling (depthwise-separable conv + pooled shortcut).
        self.down_sample_conv = TimeDownsample(
            dim_in=config.input_dim,
            dim_out=config.transformer_encoder_input_dim,
            kernel_size=config.down_sample_conv_kernel_size,
            stride=config.down_sample_conv_stride,
            dropout=config.down_sample_conv_dropout,
            padding=config.down_sample_conv_padding,
        )
        self.AddFuse = AddFuse()
        self.transformer = WrapedTransformerEncoder(
            input_dim=config.transformer_encoder_input_dim,
            transformer_input_dim=config.transformer_input_dim,
            num_layers=config.num_transformer_layers,
            nhead=config.transformer_nhead,
            dropout=config.transformer_dropout,
        )
        # Smoothness penalty on boundary scores, relaxed near true boundaries.
        self.boundary_TVLoss1D = TVLoss1D(
            beta=config.boundary_tv_loss_beta,
            lambda_tv=config.boundary_tv_loss_lambda,
            boundary_threshold=config.boundary_tv_loss_boundary_threshold,
            reduction_weight=config.boundary_tv_loss_reduction_weight,
        )
        self.label_focal_loss = SoftmaxFocalLoss(
            alpha=config.label_focal_loss_alpha, gamma=config.label_focal_loss_gamma
        )
        self.boundary_head = Head(config.transformer_input_dim, 1)
        self.function_head = Head(config.transformer_input_dim, config.num_classes)

    def cal_metrics(self, gt_info: MsaInfo, msa_info: MsaInfo):
        """Evaluate an estimated segmentation against ground truth via msaf.

        Both arguments are lists of (time, label) pairs terminated by an
        "end" sentinel whose time closes the final segment. Returns the msaf
        results dict (HitRate_*, PWF/PWP/PWR, Sf/So/Su, ...).
        """
        assert gt_info[-1][1] == "end" and msa_info[-1][1] == "end", (
            "gt_info and msa_info should end with 'end'"
        )
        gt_info_labels = [label for time_, label in gt_info][:-1]
        gt_info_inters = [time_ for time_, label in gt_info]
        # Convert the boundary-time list to (start, end) interval rows.
        gt_info_inters = np.column_stack(
            [np.array(gt_info_inters[:-1]), np.array(gt_info_inters[1:])]
        )

        msa_info_labels = [label for time_, label in msa_info][:-1]
        msa_info_inters = [time_ for time_, label in msa_info]
        msa_info_inters = np.column_stack(
            [np.array(msa_info_inters[:-1]), np.array(msa_info_inters[1:])]
        )
        result = compute_results(
            ann_inter=gt_info_inters,
            est_inter=msa_info_inters,
            ann_labels=gt_info_labels,
            est_labels=msa_info_labels,
            bins=11,
            # NOTE(review): dummy filename passed through to msaf — confirm
            # it is only used for reporting.
            est_file="test.txt",
            weight=0.58,
        )
        return result

    def cal_acc(
        self, ann_info: MsaInfo | str, est_info: MsaInfo | str, post_digit: int = 3
    ) -> float:
        """Time-weighted label-agreement accuracy between two segmentations.

        Boundary times are quantized to 10**-post_digit resolution, both
        timelines are merged over their common span, and each elementary
        interval scores its duration when the annotated and estimated labels
        agree. Returns score/duration in [0, 1].

        NOTE(review): the `| str` in the annotations looks stale — both inputs
        are iterated as (time, label) pairs; confirm strings are never passed.
        """
        # Quantize times to ints to avoid float-comparison issues.
        # NOTE(review): int() truncates the float product, so values like
        # round(2.675, 3) * 1000 == 2674.999... drop to 2674 — consider
        # int(round(time_ * 10**post_digit)) instead; confirm intended.
        ann_info_time = [
            int(round(time_, post_digit) * (10**post_digit))
            for time_, label in ann_info
        ]
        est_info_time = [
            int(round(time_, post_digit) * (10**post_digit))
            for time_, label in est_info
        ]

        # Restrict scoring to the span covered by both annotations.
        common_start_time = max(ann_info_time[0], est_info_time[0])
        common_end_time = min(ann_info_time[-1], est_info_time[-1])

        # Union of all boundary times inside the common span.
        time_points = {common_start_time, common_end_time}
        time_points.update(
            {
                time_
                for time_ in ann_info_time
                if common_start_time <= time_ <= common_end_time
            }
        )
        time_points.update(
            {
                time_
                for time_ in est_info_time
                if common_start_time <= time_ <= common_end_time
            }
        )

        time_points = sorted(time_points)
        total_duration, total_score = 0, 0

        for idx in range(len(time_points) - 1):
            duration = time_points[idx + 1] - time_points[idx]
            # Label active at the interval start, per annotation / estimate:
            # bisect_right - 1 finds the last boundary at or before the point.
            ann_label = ann_info[
                bisect.bisect_right(ann_info_time, time_points[idx]) - 1
            ][1]
            est_label = est_info[
                bisect.bisect_right(est_info_time, time_points[idx]) - 1
            ][1]
            total_duration += duration
            if ann_label == est_label:
                total_score += duration
        return total_score / total_duration

    def infer_with_metrics(self, batch, prefix: str | None = None):
        """Run no-grad inference on a batch and return losses plus metrics.

        Uses batch["msa_infos"][0] as ground truth, i.e. metric computation
        assumes a batch of one song. Keys of the returned dict are prefixed
        with *prefix* when given.
        """
        with torch.no_grad():
            logits = self.forward_func(batch)

            losses = self.compute_losses(logits, batch, prefix=None)

            # Forbid label ids invalid for this dataset by masking their
            # logits to -inf before post-processing.
            expanded_mask = batch["label_id_masks"].expand(
                -1, logits["function_logits"].size(1), -1
            )
            logits["function_logits"] = logits["function_logits"].masked_fill(
                expanded_mask, -float("inf")
            )

            msa_info = postprocess_functional_structure(
                logits=logits, config=self.config
            )
            gt_info = batch["msa_infos"][0]
            results = self.cal_metrics(gt_info=gt_info, msa_info=msa_info)

            ret_results = {
                "loss": losses["loss"].item(),
                "HitRate_3P": results["HitRate_3P"],
                "HitRate_3R": results["HitRate_3R"],
                "HitRate_3F": results["HitRate_3F"],
                "HitRate_0.5P": results["HitRate_0.5P"],
                "HitRate_0.5R": results["HitRate_0.5R"],
                "HitRate_0.5F": results["HitRate_0.5F"],
                "PWF": results["PWF"],
                "PWP": results["PWP"],
                "PWR": results["PWR"],
                "Sf": results["Sf"],
                "So": results["So"],
                "Su": results["Su"],
                "acc": self.cal_acc(ann_info=gt_info, est_info=msa_info),
            }
            if prefix:
                ret_results = prefix_dict(ret_results, prefix)

            return ret_results

    def infer(
        self,
        input_embeddings,
        dataset_ids,
        label_id_masks,
        prefix: str | None = None,
        with_logits=False,
    ):
        """No-grad inference from raw tensors (no loss, no padding mask).

        Mirrors forward_func's pipeline inline; returns the post-processed
        MsaInfo, optionally together with the raw logits dict.
        NOTE(review): `prefix` is accepted but unused in this method.
        """
        with torch.no_grad():
            input_embeddings = self.mixed_win_downsample(input_embeddings)
            input_embeddings = self.input_norm(input_embeddings)
            logits = self.down_sample_conv(input_embeddings)

            dataset_prefix = self.dataset_class_prefix(dataset_ids)
            # (B, D) -> (B, 1, D) so the embedding broadcast-adds over time.
            dataset_prefix_expand = dataset_prefix.unsqueeze(1).expand(
                logits.size(0), 1, -1
            )
            logits = self.AddFuse(x=logits, cond=dataset_prefix_expand)
            logits = self.transformer(x=logits, src_key_padding_mask=None)

            function_logits = self.function_head(logits)
            boundary_logits = self.boundary_head(logits).squeeze(-1)

            logits = {
                "function_logits": function_logits,
                "boundary_logits": boundary_logits,
            }

            # Mask out label ids that are not valid for this dataset.
            expanded_mask = label_id_masks.expand(
                -1, logits["function_logits"].size(1), -1
            )
            logits["function_logits"] = logits["function_logits"].masked_fill(
                expanded_mask, -float("inf")
            )

            msa_info = postprocess_functional_structure(
                logits=logits, config=self.config
            )

            return (msa_info, logits) if with_logits else msa_info

    def compute_losses(self, outputs, batch, prefix: str | None = None):
        """Compute boundary (BCE + TV) and function (CE + focal) losses.

        Returns a dict with "loss" (sum of the enabled, weighted terms),
        "loss_section" and "loss_function"; keys are prefixed when *prefix*
        is given.
        """
        loss = 0.0
        losses = {}

        # Per-frame boundary loss against widened (soft) boundary targets.
        loss_section = F.binary_cross_entropy_with_logits(
            outputs["boundary_logits"],
            batch["widen_true_boundaries"],
            reduction="none",
        )
        # Smoothness regularizer, down-weighted near true boundaries.
        loss_section += self.config.boundary_tvloss_weight * self.boundary_TVLoss1D(
            pred=outputs["boundary_logits"],
            target=batch["widen_true_boundaries"],
        )
        # Cross-entropy wants classes on dim 1, hence (B, T, C) -> (B, C, T).
        loss_function = F.cross_entropy(
            outputs["function_logits"].transpose(1, 2),
            batch["true_functions"].transpose(1, 2),
            reduction="none",
        )

        ttt = self.config.label_focal_loss_weight * self.label_focal_loss(
            pred=outputs["function_logits"], targets=batch["true_functions"]
        )
        loss_function += ttt

        # batch["masks"] is True on padded frames (see the transformer's
        # src_key_padding_mask convention); invert to weight valid frames.
        float_masks = (~batch["masks"]).float()
        boundary_mask = batch.get("boundary_mask", None)
        function_mask = batch.get("function_mask", None)
        if boundary_mask is not None:
            boundary_mask = (~boundary_mask).float()
        else:
            boundary_mask = 1

        if function_mask is not None:
            function_mask = (~function_mask).float()
        else:
            function_mask = 1

        # NOTE(review): mean() averages over ALL frames, padded ones included,
        # so songs with more padding get smaller losses — confirm intended
        # (a masked mean would divide by float_masks.sum() instead).
        loss_section = torch.mean(boundary_mask * float_masks * loss_section)
        loss_function = torch.mean(function_mask * float_masks * loss_function)

        loss_section *= self.config.loss_weight_section
        loss_function *= self.config.loss_weight_function

        # Config flags choose which objectives contribute to the total.
        if self.config.learn_label:
            loss += loss_function
        if self.config.learn_segment:
            loss += loss_section

        losses.update(
            loss=loss,
            loss_section=loss_section,
            loss_function=loss_function,
        )
        if prefix:
            losses = prefix_dict(losses, prefix)
        return losses

    def forward_func(self, batch):
        """Shared forward pass: batch dict -> heads' logits dict."""
        input_embeddings = batch["input_embeddings"]
        input_embeddings = self.mixed_win_downsample(input_embeddings)
        input_embeddings = self.input_norm(input_embeddings)
        logits = self.down_sample_conv(input_embeddings)

        dataset_prefix = self.dataset_class_prefix(batch["dataset_ids"])
        # (B, 1, D) conditioning broadcast-added over the time axis.
        logits = self.AddFuse(x=logits, cond=dataset_prefix.unsqueeze(1))
        src_key_padding_mask = batch["masks"]
        logits = self.transformer(x=logits, src_key_padding_mask=src_key_padding_mask)

        function_logits = self.function_head(logits)
        boundary_logits = self.boundary_head(logits).squeeze(-1)

        logits = {
            "function_logits": function_logits,
            "boundary_logits": boundary_logits,
        }
        return logits

    def forward(self, batch):
        """Training entry point: returns (logits, total_loss, losses_dict)."""
        logits = self.forward_func(batch)
        losses = self.compute_losses(logits, batch, prefix=None)
        return logits, losses["loss"], losses
| |
|