| """ |
| Legal-Longformer Model Architecture - Fully Learning-Based |
| Includes Hierarchical Longformer for document-level understanding |
| """ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoModel, AutoTokenizer |
| from typing import Dict, List, Any, Optional, Tuple |
|
|
class FullyLearningBasedLegalBERT(nn.Module):
    """
    Legal-Longformer model that learns from discovered risk patterns.
    NO hardcoded risk categories!

    A pretrained encoder (``AutoModel`` loaded from ``config.bert_model_name``)
    yields one pooled vector per sequence, which feeds three heads:

    - ``risk_classifier``: logits over ``num_discovered_risks`` risk ids
    - ``severity_regressor``: score in [0, 10] (sigmoid output scaled by 10)
    - ``importance_regressor``: score in [0, 10] (sigmoid output scaled by 10)

    A learnable scalar ``temperature`` divides the risk logits to produce
    ``calibrated_logits``.

    If the encoder cannot be loaded, ``self.bert`` is ``None`` and ``forward``
    substitutes random noise for the pooled output. Predictions in that mode
    are meaningless; it only lets the surrounding pipeline be smoke-tested.
    """

    def __init__(self, config, num_discovered_risks: int = 7):
        """
        Args:
            config: Object exposing ``bert_model_name`` and ``dropout_rate``
                (and optionally ``use_gradient_checkpointing``).
            num_discovered_risks: Number of classes produced by the risk head.
        """
        super().__init__()
        self.config = config
        self.num_discovered_risks = num_discovered_risks

        try:
            from transformers import AutoConfig

            # Dropout probabilities must be on the config *before* the model is
            # built: the model's dropout layers read them at construction time,
            # so editing ``self.bert.config`` afterwards has no effect.
            encoder_config = AutoConfig.from_pretrained(config.bert_model_name)
            encoder_config.hidden_dropout_prob = config.dropout_rate
            encoder_config.attention_probs_dropout_prob = config.dropout_rate
            self.bert = AutoModel.from_pretrained(config.bert_model_name, config=encoder_config)
        except Exception as exc:  # deliberate best-effort fallback to the mock encoder
            print("⚠️ Warning: Using mock Longformer model (transformers not available)")
            print(f"   Reason: {type(exc).__name__}: {exc}")
            self.bert = None
            hidden_size = 768
        else:
            hidden_size = self.bert.config.hidden_size
            # Kept outside the ``try`` so a checkpointing error is not mistaken
            # for a load failure (which would silently discard the real encoder).
            if getattr(config, 'use_gradient_checkpointing', False):
                self.bert.gradient_checkpointing_enable()
                print("✅ Gradient checkpointing enabled - trading computation for memory")

        # Width of the pooled representation; the mock encoder uses it too.
        self.hidden_size = hidden_size

        self.risk_classifier = nn.Sequential(
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_size // 2, num_discovered_risks)
        )

        self.severity_regressor = nn.Sequential(
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_size, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_size // 4, 1),
            nn.Sigmoid()
        )

        self.importance_regressor = nn.Sequential(
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_size, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_size // 4, 1),
            nn.Sigmoid()
        )

        # Learnable temperature for logit calibration (initialised to 1).
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                output_attentions: bool = False) -> Dict[str, torch.Tensor]:
        """Forward pass through the model.

        Args:
            input_ids: Token IDs from the tokenizer, shape (batch, seq_len).
            attention_mask: Attention mask for valid tokens, same shape.
            output_attentions: If True, include the encoder's per-layer
                attention weights under ``'attentions'`` (real encoder only).

        Returns:
            Dict with ``risk_logits`` and ``calibrated_logits`` (batch, num_risks),
            ``severity_score`` and ``importance_score`` (batch,) in [0, 10],
            ``pooled_output`` (batch, hidden_size) and, when requested and
            available, ``attentions``.
        """
        if self.bert is not None:
            outputs = self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions
            )
            pooler_output = getattr(outputs, 'pooler_output', None)
            # Fall back to the first token's final hidden state when there is no pooler.
            if pooler_output is not None:
                pooled_output = pooler_output
            else:
                pooled_output = outputs.last_hidden_state[:, 0, :]
            attentions = outputs.attentions if output_attentions else None
        else:
            # Mock encoder: random features created directly on the input's device.
            pooled_output = torch.randn(
                input_ids.size(0), self.hidden_size, device=input_ids.device
            )
            attentions = None

        risk_logits = self.risk_classifier(pooled_output)
        severity_score = self.severity_regressor(pooled_output).squeeze(-1) * 10
        importance_score = self.importance_regressor(pooled_output).squeeze(-1) * 10

        calibrated_logits = risk_logits / self.temperature

        result = {
            'risk_logits': risk_logits,
            'calibrated_logits': calibrated_logits,
            'severity_score': severity_score,
            'importance_score': importance_score,
            'pooled_output': pooled_output
        }

        if output_attentions and attentions is not None:
            result['attentions'] = attentions

        return result

    def predict_risk_pattern(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                             return_attentions: bool = False) -> Dict[str, Any]:
        """Make predictions and return interpretable (NumPy) results.

        Runs in eval mode without gradients; the module's previous
        train/eval mode is restored afterwards.

        Args:
            input_ids: Token IDs from the tokenizer.
            attention_mask: Attention mask for valid tokens.
            return_attentions: If True, include the raw attention tensors
                (left as torch tensors) when the encoder provides them.

        Returns:
            Dict of NumPy arrays: ``predicted_risk_id``, ``risk_probabilities``
            (softmax of the calibrated logits), ``confidence`` (max probability),
            ``severity_score``, ``importance_score``; optionally ``attentions``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                outputs = self.forward(input_ids, attention_mask, output_attentions=return_attentions)
        finally:
            # Don't leave dropout disabled if this is called mid-training.
            self.train(was_training)

        risk_probs = torch.softmax(outputs['calibrated_logits'], dim=-1)
        confidence, predicted_risk = risk_probs.max(dim=-1)

        result = {
            'predicted_risk_id': predicted_risk.cpu().numpy(),
            'risk_probabilities': risk_probs.cpu().numpy(),
            'confidence': confidence.cpu().numpy(),
            'severity_score': outputs['severity_score'].cpu().numpy(),
            'importance_score': outputs['importance_score'].cpu().numpy()
        }

        if return_attentions and 'attentions' in outputs:
            result['attentions'] = outputs['attentions']

        return result

    def analyze_attention(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                          tokenizer: Optional['LegalBertTokenizer'] = None) -> Dict[str, Any]:
        """Analyze attention patterns to identify important tokens for risk assessment.

        Token importance is the mean of two scores, each averaged over all
        layers and heads:

        - attention paid *by* the first token (position 0) to every token, and
        - average attention every token *receives* from all query positions.

        Padding positions are zeroed with ``attention_mask``. The analysis
        requires full (seq_len x seq_len) attention matrices; if the encoder
        returns anything else, an ``{'error': ...}`` dict is returned.
        Runs in eval mode without gradients; the previous mode is restored.

        Args:
            input_ids: Token IDs, shape (batch, seq_len).
            attention_mask: Attention mask for valid tokens, same shape.
            tokenizer: Optional wrapper used to decode the tokens of the first
                sequence in the batch.

        Returns:
            On success, a dict with:
            - token_importance: (batch, seq_len) per-token scores
            - top_token_indices / top_token_scores: top-k (k <= 10) per row
            - attention_weights: ``cls_attention`` and ``global_attention``,
              each (batch, seq_len), averaged over layers and heads
            - layer_analysis: per-layer first-token attention, averaged over heads
            - tokens / top_tokens: decoded tokens of batch item 0 (only when a
              loaded tokenizer is supplied)
            Otherwise ``{'error': <message>}``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                outputs = self.forward(input_ids, attention_mask, output_attentions=True)
        finally:
            self.train(was_training)

        attentions = outputs.get('attentions')
        if attentions is None or len(attentions) == 0:
            return {'error': 'Attention weights not available'}

        seq_len = input_ids.size(1)
        # Each layer must be (batch, heads, seq_len, seq_len). Longformer, for
        # example, reports sliding-window attentions whose last dimension is
        # the window size, which this analysis cannot interpret.
        if any(a.dim() != 4 or tuple(a.shape[-2:]) != (seq_len, seq_len) for a in attentions):
            return {'error': 'Attention weights are not full (seq_len x seq_len) matrices; '
                             'token-level attention analysis is unsupported for this encoder'}

        # (num_layers, batch, num_heads, query_len, key_len)
        all_attentions = torch.stack(attentions)

        # Attention paid by token 0 to each key, averaged over layers and heads.
        cls_attention = all_attentions[:, :, :, 0, :].mean(dim=[0, 2])
        # Attention each key receives, averaged over layers, heads and queries.
        global_attention = all_attentions.mean(dim=[0, 2, 3])

        token_importance = (cls_attention + global_attention) / 2
        # Zero out padding positions (mask == 0).
        token_importance = token_importance * attention_mask.to(token_importance.dtype)

        k = min(10, seq_len)
        top_values, top_indices = torch.topk(token_importance, k, dim=1)

        result = {
            'token_importance': token_importance.cpu().numpy(),
            'top_token_indices': top_indices.cpu().numpy(),
            'top_token_scores': top_values.cpu().numpy(),
            'attention_weights': {
                'cls_attention': cls_attention.cpu().numpy(),
                'global_attention': global_attention.cpu().numpy()
            }
        }

        # Per-layer attention from token 0, averaged over heads: (batch, seq_len).
        result['layer_analysis'] = [
            {
                'layer': layer_idx,
                'cls_attention': layer_attn[:, :, 0, :].mean(dim=1).cpu().numpy()
            }
            for layer_idx, layer_attn in enumerate(attentions)
        ]

        # Decoding is done for the first sequence in the batch only.
        if tokenizer is not None and tokenizer.tokenizer is not None:
            tokens = tokenizer.tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
            result['tokens'] = tokens
            result['top_tokens'] = [tokens[idx] for idx in top_indices[0].tolist()]

        return result
|
|
class LegalBertTokenizer:
    """Tokenizer wrapper for Legal-Longformer.

    Wraps a Hugging Face ``AutoTokenizer``. If it cannot be loaded,
    ``self.tokenizer`` is ``None`` and the methods return placeholder outputs
    (random token ids, fixed decoded strings) so downstream code can still be
    smoke-tested.
    """

    def __init__(self, model_name: str = "allenai/longformer-base-4096"):
        """
        Args:
            model_name: Hub id or local path passed to ``AutoTokenizer.from_pretrained``.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        except Exception as exc:  # deliberate best-effort fallback to the mock tokenizer
            print("⚠️ Warning: Using mock tokenizer (transformers not available)")
            print(f"   Reason: {type(exc).__name__}: {exc}")
            self.tokenizer = None

    def tokenize_clauses(self, clauses: List[str], max_length: int = 512) -> Dict[str, torch.Tensor]:
        """Tokenize legal clauses for model input.

        Args:
            clauses: Clause texts, one per batch row.
            max_length: Truncation length; with the mock tokenizer it is also
                the exact sequence width.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` tensors of shape
            (len(clauses), seq_len). Real tokenizer output is padded to the
            longest clause; mock output is random ids with an all-ones mask.
        """
        if self.tokenizer is None:
            batch_size = len(clauses)
            return {
                'input_ids': torch.randint(0, 1000, (batch_size, max_length)),
                # Integer dtype, matching what a real tokenizer returns.
                'attention_mask': torch.ones(batch_size, max_length, dtype=torch.long)
            }

        encoded = self.tokenizer(
            clauses,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        )

        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask']
        }

    def decode_tokens(self, token_ids: torch.Tensor) -> List[str]:
        """Decode a batch of token-id rows back to text (special tokens skipped).

        With the mock tokenizer, returns a placeholder string per row.
        """
        if self.tokenizer is None:
            return ["Mock decoded text"] * token_ids.size(0)

        return self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
|
|
|
|
| |
| |
| |
|
|
class HierarchicalLegalBERT(nn.Module):
    """
    Hierarchical Longformer for document-level contract understanding.

    Processes documents hierarchically so each clause is classified with the
    context of its section:

        Clause encoding (Longformer)
          -> Section aggregation (BiLSTM + attention)
          -> Document aggregation (BiLSTM + attention)

    Unlike ``FullyLearningBasedLegalBERT``, which classifies each clause
    independently, ``forward_document`` runs all clauses of a section through
    a bidirectional LSTM before the prediction heads.

    Usage:
        # Training: clause-level labels via ``forward_single_clause``
        # Inference: full documents via ``predict_document``

        document = [
            [clause1_inputs, clause2_inputs],  # Section 1
            [clause3_inputs, clause4_inputs],  # Section 2
        ]
        results = model.predict_document(document)

    where each ``clauseN_inputs`` is a dict with ``'input_ids'`` and
    ``'attention_mask'`` tensors for one clause (see ``forward_document``).

    If the encoder cannot be loaded, ``self.bert`` is ``None`` and clause
    embeddings are random noise (smoke-testing only).
    """

    def __init__(
        self,
        config,
        num_discovered_risks: int = 7,
        hidden_dim: int = 256,
        num_lstm_layers: int = 2
    ):
        """
        Args:
            config: Object exposing ``bert_model_name`` and ``dropout_rate``
                (and optionally ``use_gradient_checkpointing``).
            num_discovered_risks: Number of classes produced by the risk head.
            hidden_dim: Hidden size per direction of both BiLSTMs; heads
                consume ``2 * hidden_dim`` features.
            num_lstm_layers: Layers in each BiLSTM.
        """
        super().__init__()
        self.config = config
        self.num_discovered_risks = num_discovered_risks
        self.hidden_dim = hidden_dim

        try:
            from transformers import AutoConfig

            # Dropout probabilities must be on the config *before* the model is
            # built: the model's dropout layers read them at construction time,
            # so editing ``self.bert.config`` afterwards has no effect.
            encoder_config = AutoConfig.from_pretrained(config.bert_model_name)
            encoder_config.hidden_dropout_prob = config.dropout_rate
            encoder_config.attention_probs_dropout_prob = config.dropout_rate
            self.bert = AutoModel.from_pretrained(config.bert_model_name, config=encoder_config)
        except Exception as exc:  # deliberate best-effort fallback to the mock encoder
            print("⚠️ Warning: Using mock Longformer model")
            print(f"   Reason: {type(exc).__name__}: {exc}")
            self.bert = None
            self.bert_hidden_size = 768
        else:
            self.bert_hidden_size = self.bert.config.hidden_size
            # Kept outside the ``try`` so a checkpointing error is not mistaken
            # for a load failure (which would silently discard the real encoder).
            if getattr(config, 'use_gradient_checkpointing', False):
                self.bert.gradient_checkpointing_enable()
                print("✅ Gradient checkpointing enabled in Hierarchical model")

        # Clause embeddings -> context-aware clause representations (2 * hidden_dim).
        self.clause_to_section = nn.LSTM(
            input_size=self.bert_hidden_size,
            hidden_size=hidden_dim,
            num_layers=num_lstm_layers,
            bidirectional=True,
            dropout=config.dropout_rate if num_lstm_layers > 1 else 0,
            batch_first=True
        )

        # Section vectors -> context-aware section representations (2 * hidden_dim).
        self.section_to_document = nn.LSTM(
            input_size=hidden_dim * 2,
            hidden_size=hidden_dim,
            num_layers=num_lstm_layers,
            bidirectional=True,
            dropout=config.dropout_rate if num_lstm_layers > 1 else 0,
            batch_first=True
        )

        # Additive attention scorers used for pooling clauses / sections.
        self.clause_attention = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Tanh(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_dim, 1)
        )

        self.section_attention = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Tanh(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_dim, 1)
        )

        self.risk_classifier = nn.Sequential(
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_dim, num_discovered_risks)
        )

        self.severity_regressor = nn.Sequential(
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_dim * 2, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        self.importance_regressor = nn.Sequential(
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_dim * 2, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        # Learnable temperature for logit calibration (initialised to 1).
        self.temperature = nn.Parameter(torch.ones(1))

    def encode_clause(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Encode a batch of clauses with the Longformer.

        Returns one vector per row, shape (batch, bert_hidden_size): the
        pooler output if present, otherwise the first token's hidden state.
        The mock encoder returns random noise on the input's device.
        """
        if self.bert is None:
            return torch.randn(input_ids.size(0), self.bert_hidden_size, device=input_ids.device)

        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooler_output = getattr(outputs, 'pooler_output', None)
        if pooler_output is not None:
            return pooler_output
        return outputs.last_hidden_state[:, 0, :]

    def forward_single_clause(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass for independent clauses (training compatibility).

        Each row of ``input_ids`` is fed to ``clause_to_section`` as its own
        length-1 sequence, so no cross-clause context is used here. That
        differs from ``forward_document``, where the LSTM sees the whole
        section.

        Returns:
            Same keys as ``FullyLearningBasedLegalBERT.forward`` (without
            attentions); ``pooled_output`` has shape (batch, 2 * hidden_dim).
        """
        clause_embedding = self.encode_clause(input_ids, attention_mask)

        # (batch, 1, bert_hidden) -> (batch, 1, 2 * hidden_dim) -> (batch, 2 * hidden_dim)
        lstm_out, _ = self.clause_to_section(clause_embedding.unsqueeze(1))
        context_aware_repr = lstm_out.squeeze(1)

        risk_logits = self.risk_classifier(context_aware_repr)
        severity_score = self.severity_regressor(context_aware_repr).squeeze(-1) * 10
        importance_score = self.importance_regressor(context_aware_repr).squeeze(-1) * 10
        calibrated_logits = risk_logits / self.temperature

        return {
            'risk_logits': risk_logits,
            'calibrated_logits': calibrated_logits,
            'severity_score': severity_score,
            'importance_score': importance_score,
            'pooled_output': context_aware_repr
        }

    @staticmethod
    def _as_single_batch(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
        """Return one clause tensor as a (1, seq_len) batch on ``device``.

        Accepts 1-D (seq_len,) tensors and already-batched (1, seq_len) tensors
        (e.g. ``LegalBertTokenizer.tokenize_clauses`` output for one clause).

        Raises:
            ValueError: for any other shape, which would otherwise yield
                several embeddings for a single clause.
        """
        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(0)
        elif tensor.dim() != 2 or tensor.size(0) != 1:
            raise ValueError(
                f"Expected a clause tensor of shape (seq_len,) or (1, seq_len), "
                f"got {tuple(tensor.shape)}"
            )
        return tensor.to(device)

    def forward_document(
        self,
        document_structure: List[List[Dict[str, torch.Tensor]]]
    ) -> Dict[str, Any]:
        """
        Forward pass for a FULL DOCUMENT (inference with context).

        Args:
            document_structure: List of sections, each a list of clause inputs:
                [
                    [  # Section 1
                        {'input_ids': tensor, 'attention_mask': tensor},
                        {'input_ids': tensor, 'attention_mask': tensor}
                    ],
                    [  # Section 2
                        {'input_ids': tensor, 'attention_mask': tensor}
                    ]
                ]
                Each clause tensor is (seq_len,) or (1, seq_len).

        Returns:
            Dict with:
            - document_embedding: (1, 2 * hidden_dim); zeros if every section is empty
            - clause_predictions: one dict per clause with tensor outputs plus
              ``section_idx`` (index into ``document_structure``, empty sections
              included) and ``clause_idx`` (index within the section)
            - attention_weights: ``clause`` (one (num_clauses, 1) tensor per
              NON-empty section), ``section_indices`` (the ``section_idx`` of
              each ``clause`` entry) and ``section`` ((num_nonempty_sections, 1)
              or None)
        """
        device = next(self.parameters()).device
        section_vectors = []
        all_clause_predictions = []
        attention_weights = {'clause': [], 'section': None, 'section_indices': []}

        for section_idx, section_clauses in enumerate(document_structure):
            if not section_clauses:
                continue

            # (num_clauses, bert_hidden_size): clauses are encoded one at a time
            # because their lengths may differ.
            clause_hidden = torch.cat(
                [
                    self.encode_clause(
                        self._as_single_batch(clause_input['input_ids'], device),
                        self._as_single_batch(clause_input['attention_mask'], device),
                    )
                    for clause_input in section_clauses
                ],
                dim=0,
            )

            # (1, num_clauses, 2 * hidden_dim): each clause now sees its section.
            clause_lstm_out, _ = self.clause_to_section(clause_hidden.unsqueeze(0))

            # Attention-weighted pooling of clauses into one section vector.
            clause_attn = F.softmax(self.clause_attention(clause_lstm_out), dim=1)
            section_vectors.append(torch.sum(clause_lstm_out * clause_attn, dim=1))
            attention_weights['clause'].append(clause_attn.squeeze(0))
            attention_weights['section_indices'].append(section_idx)

            # Run the heads once over all clauses of the section.
            clause_reprs = clause_lstm_out[0]
            risk_logits = self.risk_classifier(clause_reprs)
            calibrated_logits = risk_logits / self.temperature
            severity = self.severity_regressor(clause_reprs).squeeze(-1) * 10
            importance = self.importance_regressor(clause_reprs).squeeze(-1) * 10

            for clause_idx in range(clause_reprs.size(0)):
                all_clause_predictions.append({
                    'risk_logits': risk_logits[clause_idx],
                    'calibrated_logits': calibrated_logits[clause_idx],
                    'severity_score': severity[clause_idx],
                    'importance_score': importance[clause_idx],
                    'section_idx': section_idx,
                    'clause_idx': clause_idx
                })

        if section_vectors:
            # (1, num_nonempty_sections, 2 * hidden_dim)
            section_hidden = torch.cat(section_vectors, dim=0)
            section_lstm_out, _ = self.section_to_document(section_hidden.unsqueeze(0))

            section_attn = F.softmax(self.section_attention(section_lstm_out), dim=1)
            document_vec = torch.sum(section_lstm_out * section_attn, dim=1)

            attention_weights['section'] = section_attn.squeeze(0)
        else:
            document_vec = torch.zeros(1, self.hidden_dim * 2, device=device)

        return {
            'document_embedding': document_vec,
            'clause_predictions': all_clause_predictions,
            'attention_weights': attention_weights
        }

    def predict_document(
        self,
        document_structure: List[List[Dict[str, torch.Tensor]]],
        high_risk_threshold: float = 7.0
    ) -> Dict[str, Any]:
        """Run ``forward_document`` in inference mode and return plain Python values.

        Runs in eval mode without gradients; the module's previous
        train/eval mode is restored afterwards.

        Args:
            document_structure: Same format as ``forward_document``.
            high_risk_threshold: Clauses with severity strictly above this
                value are counted in ``summary['high_risk_count']``.

        Returns:
            Dict with ``clauses`` (per-clause predictions), ``attention_weights``
            (nested lists; ``section_indices`` maps each ``clause`` entry to its
            section) and ``summary``. ``summary['num_sections']`` counts all
            sections, including empty ones.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                outputs = self.forward_document(document_structure)
        finally:
            # Don't leave dropout disabled if this is called mid-training.
            self.train(was_training)

        predictions = []
        for pred in outputs['clause_predictions']:
            risk_probs = F.softmax(pred['calibrated_logits'], dim=0).cpu().numpy()
            predicted_risk = int(risk_probs.argmax())

            predictions.append({
                'section_idx': pred['section_idx'],
                'clause_idx': pred['clause_idx'],
                'predicted_risk_id': predicted_risk,
                'risk_probabilities': risk_probs.tolist(),
                'confidence': float(risk_probs[predicted_risk]),
                'severity_score': pred['severity_score'].item(),
                'importance_score': pred['importance_score'].item()
            })

        attn = outputs['attention_weights']
        section_attn = attn['section']
        num_clauses = len(predictions)

        return {
            'clauses': predictions,
            'attention_weights': {
                'clause': [a.cpu().numpy().tolist() for a in attn['clause']],
                'section': section_attn.cpu().numpy().tolist() if section_attn is not None else None,
                'section_indices': list(attn.get('section_indices', []))
            },
            'summary': {
                'num_sections': len(document_structure),
                'num_clauses': num_clauses,
                'avg_severity': (
                    sum(p['severity_score'] for p in predictions) / num_clauses if predictions else 0
                ),
                'high_risk_count': sum(
                    1 for p in predictions if p['severity_score'] > high_risk_threshold
                )
            }
        }