| | import torch |
| | from transformers import AutoTokenizer, BatchEncoding |
| |
|
| | from mixinhelpers import CXR_Mixin, ECG_Mixin, ECHO_Mixin, Text_Mixin |
| |
|
| | """ |
| | Preprocessor classes for different modalities and their combinations. |
| | You can combine different mixins to create preprocessors for multi-modal inputs. |
| | Examples below are provided for ECHO+Text, ECG+Text, and CXR+Text. |
| | """ |
| |
|
| |
|
class BasePreprocessor:
    """Shared base that owns the text tokenizer used by every modality preprocessor."""

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        # Load the HuggingFace tokenizer once; subclasses reuse it for captions.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer = tokenizer
| |
|
| |
|
| | |
class ECHOText_Preprocessor(BasePreprocessor, ECHO_Mixin, Text_Mixin):
    """Preprocessor pairing an ECHO recording with its text report."""

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        super().__init__(model_name=model_name)

    def preprocess_echo_text(self, echo_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
        """Preprocess one ECHO/text pair; usable inside a dataloader so batches
        collate correctly — the string modality keys identify each modality.

        Args:
            echo_path: path to the echo npy file.
            text: string of the text report.

        Returns:
            (echo tensor, tokenized text dict).
        """
        echo_tensor = self.preprocess_single_echo(echo_path)
        tokenized = self.construct_caption(
            caption=text,
            tokenizer=self.tokenizer,
            modality=self.ECHO_KEY,
        )
        return (echo_tensor, tokenized)
| |
|
| |
|
class ECGText_Preprocessor(BasePreprocessor, ECG_Mixin, Text_Mixin):
    """Preprocessor pairing an ECG recording with its text report."""

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        super().__init__(model_name=model_name)

    def preprocess_ecg_text(self, ecg_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
        """Preprocess one ECG/text pair; usable inside a dataloader so batches
        collate correctly — the string modality keys identify each modality.

        Args:
            ecg_path: path to the ecg npy file.
            text: string of the text report.

        Returns:
            (ecg tensor, tokenized text dict).
        """
        ecg_tensor = self.preprocess_single_ecg(ecg_path)
        tokenized = self.construct_caption(
            caption=text,
            tokenizer=self.tokenizer,
            modality=self.ECG_KEY,
        )
        return (ecg_tensor, tokenized)
| |
|
| |
|
class CXRText_Preprocessor(BasePreprocessor, CXR_Mixin, Text_Mixin):
    """Preprocessor pairing a chest X-ray image with its text report."""

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        super().__init__(model_name=model_name)

    def preprocess_cxr_text(self, cxr_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
        """Preprocess one CXR/text pair; usable inside a dataloader so batches
        collate correctly — the string modality keys identify each modality.

        Args:
            cxr_path: path to the cxr image file.
            text: string of the text report.

        Returns:
            (cxr tensor, tokenized text dict).
        """
        cxr_tensor = self.preprocess_single_cxr(cxr_path)
        tokenized = self.construct_caption(
            caption=text,
            tokenizer=self.tokenizer,
            modality=self.VISION_KEY,
        )
        return (cxr_tensor, tokenized)
| |
|