| | --- |
| | license: mit |
| | tags: |
| | - braindecode |
| | --- |
| | |
| | Shallow conversion of the original LaBraM checkpoint weights for use with braindecode. |
| |
|
| | ```python |
| | |
| | #!/usr/bin/env python3 |
| | """ |
| | Complete LaBraM Weight Transfer Script |
| | |
| | Combines explicit weight mapping with full backbone transfer. |
| | Uses precise key renaming to transfer all compatible parameters. |
| | |
| | Transfers weights from LaBraM checkpoint to Braindecode Labram model. |
| | """ |
| | |
| | import torch |
| | import argparse |
| | from braindecode.models import Labram |
| | |
| | |
def create_weight_mapping():
    """
    Build the explicit LaBraM -> Braindecode key-renaming table.

    The table covers only the temporal-convolution part of ``patch_embed``:
    the ``weight`` and ``bias`` of ``conv1``..``conv3`` and
    ``norm1``..``norm3``. Every entry renames
    ``student.patch_embed.<layer>.<param>`` to
    ``patch_embed.temporal_conv.<layer>.<param>``.

    Other backbone layers (blocks, embeddings, norm, fc_norm) are not
    listed here; they are handled by removing the 'student.' prefix in
    process_state_dict().

    Returns:
    --------
    dict : Mapping from checkpoint key (str) to Braindecode key (str)
    """
    source_prefix = 'student.patch_embed'
    target_prefix = 'patch_embed.temporal_conv'

    # Entry order: conv<i>.weight, conv<i>.bias, norm<i>.weight, norm<i>.bias
    # for i = 1, 2, 3.
    return {
        f'{source_prefix}.{layer}{index}.{param}':
            f'{target_prefix}.{layer}{index}.{param}'
        for index in (1, 2, 3)
        for layer in ('conv', 'norm')
        for param in ('weight', 'bias')
    }
| | |
| | |
def process_state_dict(state_dict, weight_mapping):
    """
    Process checkpoint state dict with explicit mapping.

    Each key is handled by the first rule that matches, in this order:

    1. Keys containing ``'head'`` are skipped (classification head).
    2. Keys present in ``weight_mapping`` are renamed to the mapped key.
    3. Remaining keys containing ``'patch_embed'`` but not
       ``'temporal_conv'`` are skipped.
    4. Keys starting with ``'student.'`` have that leading prefix removed.
    5. Any other key is copied unchanged.

    Parameters:
    -----------
    state_dict : dict
        Original checkpoint state dictionary
    weight_mapping : dict
        Explicit mapping for special layers (patch_embed)

    Returns:
    --------
    new_state : dict
        Processed state dict ready for Braindecode model
    mapped_keys : list of (str, str)
        ``(original_key, new_key)`` for every entry kept in ``new_state``
    skipped_keys : list of (str, str)
        ``(original_key, reason)`` for every entry that was dropped
    """
    student_prefix = 'student.'

    new_state = {}
    mapped_keys = []
    skipped_keys = []

    for key, value in state_dict.items():
        # Skip classification head (task-specific)
        if 'head' in key:
            skipped_keys.append((key, 'head layer'))
            continue

        # Use explicit mapping for patch_embed temporal_conv
        if key in weight_mapping:
            new_key = weight_mapping[key]
        # Skip original patch_embed if not in mapping (SegmentPatch)
        elif 'patch_embed' in key and 'temporal_conv' not in key:
            skipped_keys.append((key, 'patch_embed (non-temporal)'))
            continue
        # For backbone layers, remove 'student.' prefix
        elif key.startswith(student_prefix):
            # Slice off only the leading prefix: str.replace() would also
            # delete any later 'student.' occurrence inside the key.
            new_key = key[len(student_prefix):]
        # Keep other keys as-is
        else:
            new_key = key

        new_state[new_key] = value
        mapped_keys.append((key, new_key))

    return new_state, mapped_keys, skipped_keys
| | |
| | |
def transfer_labram_weights(
    checkpoint_path,
    n_times=1600,
    n_chans=64,
    n_outputs=4,
    output_path=None,
    verbose=True
):
    """
    Transfer LaBraM weights to Braindecode Labram using explicit mapping.

    Loads the checkpoint on CPU, renames its keys with
    create_weight_mapping() / process_state_dict(), builds a ``Labram``
    model with ``neural_tokenizer=True`` and loads the renamed weights with
    ``strict=False``. Progress is reported with ``print``.

    Parameters:
    -----------
    checkpoint_path : str
        Path to LaBraM checkpoint
    n_times : int
        Number of time samples
    n_chans : int
        Number of channels
    n_outputs : int
        Number of output classes
    output_path : str
        Where to save the model ``state_dict``; nothing is saved when falsy
    verbose : bool
        Print the first skipped keys and run a test forward pass

    Returns:
    --------
    model : Labram
        Model with transferred weights
    stats : dict
        Transfer statistics: ``'original'`` (entries in the checkpoint state),
        ``'transferred'`` (entries kept after key processing), ``'skipped'``,
        ``'transfer_rate'`` (formatted percentage string), plus
        ``'missing_keys'`` and ``'unexpected_keys'`` (counts reported by
        ``load_state_dict``)
    """

    print("\n" + "="*70)
    print("LaBraM → Braindecode Weight Transfer")
    print("="*70)

    # Load checkpoint
    print(f"\nLoading checkpoint: {checkpoint_path}")
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, which can execute code. Only load checkpoints from a trusted
    # source. NOTE(review): presumably needed because the checkpoint stores
    # non-tensor objects — confirm whether weights_only=True would suffice.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # Extract model state
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state = checkpoint['model']
    else:
        state = checkpoint

    original_params = len(state)
    print(f"Original checkpoint: {original_params} parameters")

    # Create weight mapping
    weight_mapping = create_weight_mapping()

    # Process state dict
    print("\nProcessing checkpoint...")
    new_state, mapped_keys, skipped_keys = process_state_dict(state, weight_mapping)

    transferred_params = len(mapped_keys)
    # Guard against an empty checkpoint, which would otherwise raise
    # ZeroDivisionError; computed once and reused for the returned stats.
    if original_params:
        transfer_rate = transferred_params / original_params * 100
    else:
        transfer_rate = 0.0
    print(f"Mapped keys: {transferred_params} ({transfer_rate:.1f}%)")
    print(f"Skipped keys: {len(skipped_keys)}")

    if verbose and skipped_keys:
        print("\nSkipped layers:")
        for key, reason in skipped_keys[:5]:  # Show first 5
            print(f" - {key:50s} ({reason})")
        if len(skipped_keys) > 5:
            print(f" ... and {len(skipped_keys) - 5} more")

    # Create model
    print("\nCreating Labram model:")
    print(f" n_times: {n_times}")
    print(f" n_chans: {n_chans}")
    print(f" n_outputs: {n_outputs}")
    model = Labram(
        n_times=n_times,
        n_chans=n_chans,
        n_outputs=n_outputs,
        neural_tokenizer=True,
    )

    # Load weights. strict=False tolerates parameters absent from the
    # checkpoint (e.g. the skipped head) and extra checkpoint keys.
    print("\nLoading weights into model...")
    incompatible = model.load_state_dict(new_state, strict=False)

    missing_count = len(incompatible.missing_keys)
    unexpected_count = len(incompatible.unexpected_keys)

    if missing_count > 0:
        print(f" Missing keys: {missing_count} (expected - will be initialized)")
    if unexpected_count > 0:
        print(f" Unexpected keys: {unexpected_count}")

    # Test forward pass
    if verbose:
        print("\nTesting forward pass...")
        # Random batch of 2 samples, only used to check the model runs.
        x = torch.randn(2, n_chans, n_times)
        with torch.no_grad():
            output = model(x)
        print(f" Input shape: {x.shape}")
        print(f" Output shape: {output.shape}")
        print(" ✅ Forward pass successful!")

    # Save model if output_path provided
    if output_path:
        print(f"\nSaving model to: {output_path}")
        torch.save(model.state_dict(), output_path)
        print(" ✅ Model saved")

    stats = {
        'original': original_params,
        'transferred': transferred_params,
        'skipped': len(skipped_keys),
        'transfer_rate': f"{transfer_rate:.1f}%",
        'missing_keys': missing_count,
        'unexpected_keys': unexpected_count,
    }

    return model, stats
| | |
| | |
# Command-line entry point: parse the options, run the transfer once and
# print a summary of the returned statistics.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Transfer LaBraM weights to Braindecode Labram',
        # Raw formatter keeps the line breaks of the examples below.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Default transfer (backbone parameters)
  python labram_complete_transfer.py

  # Transfer and save model
  python labram_complete_transfer.py --output labram_weights.pt

  # Custom EEG parameters
  python labram_complete_transfer.py --n-times 2000 --n-chans 62 --n-outputs 2

  # Custom checkpoint path
  python labram_complete_transfer.py --checkpoint path/to/checkpoint.pth
"""
    )

    parser.add_argument(
        '--checkpoint',
        type=str,
        default='LaBraM/checkpoints/labram-base.pth',
        help='Path to LaBraM checkpoint (default: LaBraM/checkpoints/labram-base.pth)'
    )
    parser.add_argument(
        '--n-times',
        type=int,
        default=1600,
        help='Number of time samples (default: 1600)'
    )
    parser.add_argument(
        '--n-chans',
        type=int,
        default=64,
        help='Number of channels (default: 64)'
    )
    parser.add_argument(
        '--n-outputs',
        type=int,
        default=4,
        help='Number of output classes (default: 4)'
    )
    parser.add_argument(
        '--output',
        type=str,
        default=None,
        help='Output file path to save model weights'
    )
    # NOTE(review): --device is accepted but never forwarded;
    # transfer_labram_weights() has no device parameter and loads the
    # checkpoint with map_location='cpu'.
    parser.add_argument(
        '--device',
        type=str,
        default='cpu',
        help='Device to use (default: cpu)'
    )

    args = parser.parse_args()

    # No banner here: transfer_labram_weights() prints the same banner as
    # its first output, so printing it in this block showed it twice.

    # Transfer weights
    model, stats = transfer_labram_weights(
        checkpoint_path=args.checkpoint,
        n_times=args.n_times,
        n_chans=args.n_chans,
        n_outputs=args.n_outputs,
        output_path=args.output,
        verbose=True
    )

    print("\n" + "="*70)
    print("✅ TRANSFER COMPLETE")
    print("="*70)
    print(f"Original parameters: {stats['original']}")
    print(f"Transferred: {stats['transferred']} ({stats['transfer_rate']})")
    print(f"Skipped: {stats['skipped']}")
    print("="*70)
| | |
| | ``` |