MTL Peptide Classifier (21 Tasks)

A multi-task learning (MTL) peptide classifier covering 21 binary peptide-activity tasks. It is built on a frozen ESM-2 (650M) backbone with a parallel Transformer + CNN feature extractor and one classification head per task, following a PDeepPP-inspired design.

Held-out Test Set Performance (Averaged across 21 tasks)

| Metric | Value |
| --- | --- |
| Accuracy | 87.37% |
| F1 | 84.80% |
| AUC | 92.82% |
| MCC | 73.42% |

Best validation F1, averaged across tasks (used for checkpoint selection): 83.43%

Per-Task Test Metrics

| Task | ACC | F1 | AUC | MCC |
| --- | --- | --- | --- | --- |
| AntiMRSA | 0.9899 | 0.9667 | 0.9968 | 0.9607 |
| Anticancer | 0.7267 | 0.7330 | 0.8046 | 0.4540 |
| ACE_inhibitory | 0.7798 | 0.7901 | 0.8636 | 0.5623 |
| Antioxidant | 0.7758 | 0.7658 | 0.8393 | 0.5508 |
| Bitter | 0.8750 | 0.8689 | 0.9609 | 0.7533 |
| Antimalarial | 0.9736 | 0.7600 | 0.9260 | 0.7523 |
| Antimicrobial | 0.9700 | 0.9486 | 0.9824 | 0.9275 |
| Signal_peptide | 0.9919 | 0.9920 | 0.9985 | 0.9839 |
| Antifungal | 0.9488 | 0.9479 | 0.9851 | 0.8983 |
| Antimalarial_alt | 0.9755 | 0.9231 | 0.9976 | 0.9124 |
| Anticancer_alt | 0.9278 | 0.9275 | 0.9676 | 0.8557 |
| Anti_parasitic | 0.6957 | 0.5882 | 0.8299 | 0.4587 |
| Umami | 0.8202 | 0.6923 | 0.9245 | 0.5697 |
| Quorum_sensing | 0.8750 | 0.8718 | 0.9650 | 0.7509 |
| Antibacterial | 0.9478 | 0.9466 | 0.9739 | 0.8965 |
| NeuroPred | 0.8876 | 0.8875 | 0.9478 | 0.7753 |
| Toxicity | 0.9310 | 0.9254 | 0.9623 | 0.8613 |
| Antiviral | 0.8483 | 0.8432 | 0.9125 | 0.6981 |
| DPPIV_inhibitory | 0.8346 | 0.8254 | 0.9420 | 0.6729 |
| BBP | 0.8158 | 0.7879 | 0.9224 | 0.6547 |
| TTCA | 0.7563 | 0.8154 | 0.7891 | 0.4688 |
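
The headline averages appear to be unweighted (macro) means over the 21 per-task rows. The sketch below recomputes them under that assumption, with the values copied from the per-task table; the names `PER_TASK` and `macro_average` are illustrative and not part of the released code.

```python
# Per-task test metrics copied from the table above: (ACC, F1, AUC, MCC).
PER_TASK = {
    "AntiMRSA": (0.9899, 0.9667, 0.9968, 0.9607),
    "Anticancer": (0.7267, 0.7330, 0.8046, 0.4540),
    "ACE_inhibitory": (0.7798, 0.7901, 0.8636, 0.5623),
    "Antioxidant": (0.7758, 0.7658, 0.8393, 0.5508),
    "Bitter": (0.8750, 0.8689, 0.9609, 0.7533),
    "Antimalarial": (0.9736, 0.7600, 0.9260, 0.7523),
    "Antimicrobial": (0.9700, 0.9486, 0.9824, 0.9275),
    "Signal_peptide": (0.9919, 0.9920, 0.9985, 0.9839),
    "Antifungal": (0.9488, 0.9479, 0.9851, 0.8983),
    "Antimalarial_alt": (0.9755, 0.9231, 0.9976, 0.9124),
    "Anticancer_alt": (0.9278, 0.9275, 0.9676, 0.8557),
    "Anti_parasitic": (0.6957, 0.5882, 0.8299, 0.4587),
    "Umami": (0.8202, 0.6923, 0.9245, 0.5697),
    "Quorum_sensing": (0.8750, 0.8718, 0.9650, 0.7509),
    "Antibacterial": (0.9478, 0.9466, 0.9739, 0.8965),
    "NeuroPred": (0.8876, 0.8875, 0.9478, 0.7753),
    "Toxicity": (0.9310, 0.9254, 0.9623, 0.8613),
    "Antiviral": (0.8483, 0.8432, 0.9125, 0.6981),
    "DPPIV_inhibitory": (0.8346, 0.8254, 0.9420, 0.6729),
    "BBP": (0.8158, 0.7879, 0.9224, 0.6547),
    "TTCA": (0.7563, 0.8154, 0.7891, 0.4688),
}


def macro_average(per_task):
    """Unweighted mean of each metric across tasks, returned as percentages."""
    n_tasks = len(per_task)
    columns = zip(*per_task.values())  # one column per metric
    return tuple(round(100 * sum(col) / n_tasks, 2) for col in columns)


acc, f1, auc, mcc = macro_average(PER_TASK)
print(f"ACC {acc:.2f}%  F1 {f1:.2f}%  AUC {auc:.2f}%  MCC {mcc:.2f}%")
```

Under this assumption the recomputed values agree with the reported averages to two decimal places, which suggests every task carries equal weight regardless of its test-set size.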

Architecture

  • Shared encoder: frozen ESM-2 (facebook/esm2_t33_650M_UR50D, 650M params) + learnable base embedding, mixed at esm_ratio=0.9
  • Feature extraction (parallel): 4-layer Transformer + CNN (kernel=7, padding=3) → concatenated to 2560-dim features
  • Heads: 21 binary classifiers (2560 → 256 → 128 → 2) with masked average pooling
  • Loss: TIM (Threshold-Independent Multi-task) loss + label smoothing 0.1
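
Two of the steps listed above, the embedding mix and the masked average pooling, can be sketched in NumPy. This is an illustration under stated assumptions, not the code in mtl_peptide_classifier.py: it assumes `esm_ratio` is the weight of a convex combination between the ESM-2 and learnable embeddings, and that the 2560-dim feature is two 1280-dim branch outputs concatenated. The function names are hypothetical.

```python
import numpy as np


def mix_embeddings(esm_emb, base_emb, esm_ratio=0.9):
    # ASSUMPTION: convex combination; the actual mixing rule may differ.
    return esm_ratio * esm_emb + (1.0 - esm_ratio) * base_emb


def masked_mean_pool(features, attention_mask):
    # features: (batch, seq_len, dim); attention_mask: (batch, seq_len), 1 = real token.
    mask = attention_mask[..., None].astype(features.dtype)
    summed = (features * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1.0, None)  # guard against all-padding rows
    return summed / counts


rng = np.random.default_rng(0)
esm = rng.normal(size=(2, 5, 1280))   # stand-in for frozen ESM-2 outputs
base = rng.normal(size=(2, 5, 1280))  # stand-in for the learnable embedding
mask = np.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])

mixed = mix_embeddings(esm, base)
# Stand-in for the concatenated Transformer and CNN branches (1280 + 1280).
features = np.concatenate([mixed, mixed], axis=-1)
pooled = masked_mean_pool(features, mask)
print(pooled.shape)  # (2, 2560)
```

Dividing by the number of unmasked tokens rather than the padded length keeps the pooled vector independent of how much padding a sequence received.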

Tasks

| # | Task | Source |
| --- | --- | --- |
| 1 | ACE_inhibitory | UniDL4BioPep |
| 2 | DPPIV_inhibitory | UniDL4BioPep |
| 3 | Bitter | UniDL4BioPep |
| 4 | Umami | UniDL4BioPep |
| 5 | Antimicrobial | UniDL4BioPep |
| 6 | Antimalarial (main) | UniDL4BioPep |
| 7 | Antimalarial_alt | UniDL4BioPep |
| 8 | Quorum_sensing | UniDL4BioPep |
| 9 | Anticancer (main) | UniDL4BioPep |
| 10 | Anticancer_alt | UniDL4BioPep |
| 11 | AntiMRSA | UniDL4BioPep |
| 12 | TTCA | UniDL4BioPep |
| 13 | BBP | UniDL4BioPep |
| 14 | Anti_parasitic | UniDL4BioPep |
| 15 | NeuroPred | UniDL4BioPep |
| 16 | Antibacterial | UniDL4BioPep |
| 17 | Antifungal | UniDL4BioPep |
| 18 | Antiviral | UniDL4BioPep |
| 19 | Toxicity | UniDL4BioPep |
| 20 | Signal_peptide | local dataset |
| 21 | Antioxidant | UniDL4BioPep (antioxidant_FRS) |

Usage

import os
from huggingface_hub import hf_hub_download
import torch
from transformers import EsmTokenizer

from mtl_peptide_classifier import MTLPeptideClassifier, get_all_peptide_tasks

REPO = "minhquoc95/MTL-PepPred"
checkpoint_dir = "MTL-Peptide-Classifier"
os.makedirs(checkpoint_dir, exist_ok=True)

for fname in ["heads.pt", "shared_backbone.pt", "ablation_config.json"]:
    hf_hub_download(repo_id=REPO, filename=fname, local_dir=checkpoint_dir)

tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
task_configs = get_all_peptide_tasks("datasets")  # needs local datasets/ dir for task names

model = MTLPeptideClassifier(
    task_configs=task_configs,
    hidden_dim=1280,
    esm_ratio=0.9,
    num_transformer_layers=4,
    dropout=0.3,
    use_transformer=True,
    use_cnn=True,
    unfreeze_esm=False,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
backbone = torch.load(f"{checkpoint_dir}/shared_backbone.pt", map_location=device)
heads = torch.load(f"{checkpoint_dir}/heads.pt", map_location=device)

model.base_embed.load_state_dict(backbone["base_embed"])
if "transformer" in backbone:
    model.transformer.load_state_dict(backbone["transformer"])
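if "cnn" in backbone:
    model.cnn.load_state_dict(backbone["cnn"])
# Load LayerNorm under its own guard so it does not depend on CNN weights being present.
if "layer_norm" in backbone:
    model.layer_norm.load_state_dict(backbone["layer_norm"])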
for name, head in model.heads.items():
    if name in heads:
        head.load_state_dict(heads[name])

model = model.to(device).eval()

sequence = "MKWVTFISLLFLFSSAYSRGVFRR"
tokens = " ".join(list(sequence))
inputs = tokenizer(tokens, return_tensors="pt", max_length=128, padding="max_length", truncation=True)
with torch.no_grad():
    logits = model(inputs["input_ids"].to(device), inputs["attention_mask"].to(device), task_name="Antimicrobial")
    probs = torch.softmax(logits, dim=-1)
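
Each head emits two logits, so a prediction still has to be read off the softmax output. The helper below is a minimal standard-library sketch of that step. It assumes class index 1 is the positive (active) class and uses a 0.5 threshold; neither is stated in this card, so check both against the label encoding in the training data. `predict_label` and the example logits are hypothetical.

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits, threshold=0.5, positive_index=1):
    # ASSUMPTION: index 1 is the "active" class; verify against the dataset labels.
    p_positive = softmax(logits)[positive_index]
    return p_positive, p_positive >= threshold


# Hypothetical logits for one sequence on one task.
prob, is_active = predict_label([-1.2, 2.3])
print(f"P(active) = {prob:.3f}, predicted active: {is_active}")
```

With the real model, the two-element list would come from `logits[0].tolist()` for a single sequence.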

Training

  • Base model: facebook/esm2_t33_650M_UR50D (frozen)
  • Batch size: 16, learning rate: 1e-4, 50 epochs, dropout: 0.3
  • 3-way split per task: the training data is split 80% train / 20% validation (used for checkpoint selection), and a separate held-out test CSV is evaluated once
  • Mixed precision, gradient clipping 1.0, cosine LR with 5 warmup epochs
  • TIM loss + label smoothing 0.1
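
The learning-rate schedule listed above can be sketched as a plain function of the epoch index. This assumes linear warmup, per-epoch stepping, and cosine decay toward zero; the card only states "cosine LR with 5 warmup epochs", so the real schedule may step per batch or use a nonzero floor. `lr_at_epoch` is an illustrative name.

```python
import math


def lr_at_epoch(epoch, base_lr=1e-4, warmup_epochs=5, total_epochs=50):
    """Learning rate for a 0-indexed epoch: linear warmup, then cosine decay."""
    if epoch < warmup_epochs:
        # ASSUMPTION: linear ramp that reaches base_lr at the last warmup epoch.
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


schedule = [lr_at_epoch(e) for e in range(50)]
for e in (0, 4, 5, 27, 49):
    print(f"epoch {e:2d}: lr = {schedule[e]:.2e}")
```

In a PyTorch training loop the same shape could be obtained by dividing this function by `base_lr` and passing it to `torch.optim.lr_scheduler.LambdaLR` as the multiplier.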

Files

  • heads.pt: per-task classifier heads
  • shared_backbone.pt: base embedding, Transformer, CNN, LayerNorm
  • ablation_config.json: architecture configuration for reproducibility
  • test_results.json: held-out test metrics (per task + averages)
  • mtl_peptide_classifier.py: model code

Requirements

torch>=2.0.0
transformers>=4.30.0
huggingface_hub
numpy
pandas
scikit-learn