# Source code for pyhealth.models.medfuse

# Authors: Sean Nian, Tony Hong, Yaqi Qiao
# Description: MedFuse model implementation for PyHealth.

from __future__ import annotations

from typing import Dict, Optional, Tuple, cast

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

from pyhealth.datasets import SampleDataset
from pyhealth.models.base_model import BaseModel


class MedFuseLayer(nn.Module):
    """MedFuse fusion layer.

    Fuses EHR time-series and chest X-ray (CXR) representations using an
    LSTM-based sequential fusion strategy that supports missing CXR modality.

    Follows Paper §3.2, Eq. 4: the fusion LSTM consumes the token sequence
    ``v_fusion = [v_ehr, v*_cxr]`` (length 1 when CXR is missing). The EHR
    representation is the first input token, not an initial hidden state.

    Paper:
        Hayat, N., Geras, K. J., & Shamout, F. E. (2022). MedFuse: Multi-modal
        fusion with clinical time-series data and chest X-ray images.
        MLHC 2022.

    Args:
        ehr_input_dim: Dimension of EHR features at each timestep.
        ehr_hidden_dim: Hidden dimension of the EHR LSTM encoder.
            Default is 256.
        ehr_num_layers: Number of stacked LSTM layers for EHR. Default is 2.
        cxr_backbone: ResNet variant for CXR encoder. One of
            ``SUPPORTED_CXR_BACKBONES``. Default is ``"resnet34"``.
        cxr_pretrained: Whether to use pretrained CXR encoder weights.
            Default is True.
        fusion_hidden_dim: Hidden dimension of fusion LSTM. Default is 512.
        projection_dim: Dimension to project CXR features to. Must match
            ``ehr_hidden_dim``. Default is 256.
        dropout: Dropout rate. Applied between stacked EHR LSTM layers (only
            when ``ehr_num_layers > 1``) and to the EHR and fused
            representations. Default is 0.5.
        num_labels: Number of output labels. Default is 1.

    Raises:
        ValueError: If ``projection_dim != ehr_hidden_dim`` or
            ``cxr_backbone`` is not supported.
        ImportError: If torchvision is not installed.
    """

    SUPPORTED_CXR_BACKBONES = ("resnet18", "resnet34", "resnet50")

    def __init__(
        self,
        ehr_input_dim: int,
        ehr_hidden_dim: int = 256,
        ehr_num_layers: int = 2,
        cxr_backbone: str = "resnet34",
        cxr_pretrained: bool = True,
        fusion_hidden_dim: int = 512,
        projection_dim: int = 256,
        dropout: float = 0.5,
        num_labels: int = 1,
    ) -> None:
        super().__init__()

        # Both fusion tokens are fed to the same LSTM, so the EHR hidden
        # state and the projected CXR features must share one width.
        if projection_dim != ehr_hidden_dim:
            raise ValueError(
                "projection_dim must match ehr_hidden_dim for fusion sequencing."
            )
        if cxr_backbone not in self.SUPPORTED_CXR_BACKBONES:
            raise ValueError(
                f"Unsupported cxr_backbone: {cxr_backbone}. "
                f"Supported values: {self.SUPPORTED_CXR_BACKBONES}."
            )

        self.ehr_encoder = nn.LSTM(
            input_size=ehr_input_dim,
            hidden_size=ehr_hidden_dim,
            num_layers=ehr_num_layers,
            batch_first=True,
            # nn.LSTM only applies dropout between stacked layers.
            dropout=dropout if ehr_num_layers > 1 else 0.0,
        )
        self.cxr_encoder, cxr_feature_dim = self._build_cxr_encoder(
            cxr_backbone=cxr_backbone,
            cxr_pretrained=cxr_pretrained,
        )
        self.projection = nn.Linear(cxr_feature_dim, projection_dim)
        self.fusion_lstm = nn.LSTM(
            input_size=projection_dim,
            hidden_size=fusion_hidden_dim,
            num_layers=1,
            batch_first=True,
        )
        self.dropout_layer = nn.Dropout(dropout)
        self.classifier = nn.Linear(fusion_hidden_dim, num_labels)

    def _build_cxr_encoder(
        self,
        cxr_backbone: str,
        cxr_pretrained: bool,
    ) -> Tuple[nn.Module, int]:
        """Builds the CXR backbone and returns encoder + feature dimension.

        The backbone's final ``fc`` layer is replaced by ``nn.Identity`` so
        the encoder emits pooled features instead of class logits.

        Args:
            cxr_backbone: Name of a constructor in ``torchvision.models``.
            cxr_pretrained: Whether to load pretrained weights.

        Returns:
            A tuple ``(backbone, feature_dim)`` where ``feature_dim`` is the
            ``in_features`` of the removed ``fc`` layer.

        Raises:
            ImportError: If torchvision is not installed.
            ValueError: If the backbone has no linear ``fc`` layer.
        """
        try:
            import torchvision.models as tv_models
        except ImportError as exc:
            raise ImportError(
                "torchvision is required to use MedFuse CXR backbones."
            ) from exc

        constructor = getattr(tv_models, cxr_backbone)

        if cxr_pretrained:
            # Feature-detect the weights-enum API instead of catching every
            # exception: a failed weight download must surface as-is rather
            # than be masked by a retry through the legacy keyword.
            get_model_weights = getattr(tv_models, "get_model_weights", None)
            if get_model_weights is None:
                backbone = constructor(pretrained=True)
            else:
                backbone = constructor(
                    weights=get_model_weights(cxr_backbone).DEFAULT
                )
        else:
            try:
                backbone = constructor(weights=None)
            except TypeError:
                # Constructor does not accept the ``weights`` keyword.
                backbone = constructor(pretrained=False)

        if not hasattr(backbone, "fc") or not isinstance(backbone.fc, nn.Linear):
            raise ValueError(
                f"Backbone {cxr_backbone} must expose a linear `fc` layer."
            )

        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()
        return backbone, feature_dim

    def _encode_cxr(self, cxr_input: torch.Tensor) -> torch.Tensor:
        """Encodes CXR images and projects them to the fusion dimension.

        Args:
            cxr_input: Image tensor of shape
                ``[batch, channels, height, width]`` with 1 or 3 channels.
                Single-channel input is repeated to 3 channels.

        Returns:
            Tensor of shape ``[batch, projection_dim]``.

        Raises:
            ValueError: If the input is not 4-D or has an unsupported
                channel count.
        """
        if cxr_input.dim() != 4:
            raise ValueError(
                "cxr_input must have shape [batch, channels, height, width]."
            )

        channels = cxr_input.size(1)
        if channels == 1:
            cxr_input = cxr_input.repeat(1, 3, 1, 1)
        elif channels != 3:
            raise ValueError("cxr_input must have 1 or 3 channels.")

        cxr_features = self.cxr_encoder(cxr_input)
        if cxr_features.dim() > 2:
            cxr_features = torch.flatten(cxr_features, start_dim=1)
        return self.projection(cxr_features)

    def forward(
        self,
        ehr_input: torch.Tensor,
        cxr_input: Optional[torch.Tensor] = None,
        cxr_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Runs MedFuse fusion.

        Args:
            ehr_input: EHR tensor of shape
                ``[batch, seq_len, ehr_input_dim]``.
            cxr_input: Optional CXR tensor of shape
                ``[batch, channels, height, width]``.
            cxr_mask: Optional tensor with ``batch`` elements indicating CXR
                availability (1 = present, 0 = missing). Ignored when
                ``cxr_input`` is None; when omitted, every CXR is treated as
                present.

        Returns:
            Logits tensor of shape ``[batch, num_labels]``.

        Raises:
            ValueError: If ``ehr_input`` is not 3-D, or the CXR batch size or
                mask size does not match the EHR batch size.
        """
        if ehr_input.dim() != 3:
            raise ValueError(
                "ehr_input must have shape [batch, seq_len, ehr_input_dim]."
            )

        # The final hidden state of the top EHR LSTM layer is token 1.
        _, (ehr_hidden, _) = self.ehr_encoder(ehr_input)
        ehr_repr = self.dropout_layer(ehr_hidden[-1])

        batch_size = ehr_repr.size(0)
        device = ehr_input.device
        # Every sample contributes at least the EHR token.
        sequence_lengths = torch.ones(batch_size, dtype=torch.long, device=device)
        fusion_tokens = [ehr_repr.unsqueeze(1)]

        if cxr_input is not None:
            if cxr_input.size(0) != batch_size:
                raise ValueError(
                    "cxr_input batch size must match ehr_input batch size."
                )

            if cxr_mask is None:
                cxr_present_mask = torch.ones(
                    batch_size, dtype=torch.bool, device=device
                )
            else:
                cxr_present_mask = cxr_mask.to(device).view(-1).bool()
                if cxr_present_mask.numel() != batch_size:
                    raise ValueError("cxr_mask must have shape [batch].")

            # Rows for missing CXRs stay zero; they are padding that the
            # packed sequence below never feeds to the fusion LSTM.
            cxr_projected = torch.zeros(
                batch_size,
                self.projection.out_features,
                dtype=ehr_repr.dtype,
                device=device,
            )
            if cxr_present_mask.any():
                # Only encode images that are actually present.
                present_projected = self._encode_cxr(cxr_input[cxr_present_mask])
                # Masked assignment needs matching dtypes; cast so that
                # mixed-precision runs cannot fail here.
                cxr_projected[cxr_present_mask] = present_projected.to(
                    cxr_projected.dtype
                )
                sequence_lengths[cxr_present_mask] = 2
            fusion_tokens.append(cxr_projected.unsqueeze(1))

        fusion_sequence = torch.cat(fusion_tokens, dim=1)
        # Packing makes the LSTM stop after each sample's own length, so the
        # returned hidden state is per-sample: after 1 or 2 tokens.
        packed_sequence = rnn_utils.pack_padded_sequence(
            fusion_sequence,
            lengths=sequence_lengths.cpu(),
            batch_first=True,
            enforce_sorted=False,
        )
        _, (fusion_hidden, _) = self.fusion_lstm(packed_sequence)

        fused_repr = self.dropout_layer(fusion_hidden[-1])
        return self.classifier(fused_repr)
class MedFuse(BaseModel):
    """MedFuse model for multi-modal clinical prediction.

    This model fuses clinical time-series (EHR) and chest X-ray (CXR)
    representations using an LSTM-based fusion module. It handles missing CXR
    via variable sequence lengths, where each sample contributes either:

    - ``[v_ehr]`` (EHR only), or
    - ``[v_ehr, v_cxr]`` (EHR + CXR).

    Paper:
        Hayat, N., Geras, K. J., & Shamout, F. E. (2022). MedFuse: Multi-modal
        fusion with clinical time-series data and chest X-ray images.
        MLHC 2022.

    Args:
        dataset: Sample dataset used for model configuration.
        ehr_feature_key: Input key for EHR tensor. Default is ``"ehr"``.
        cxr_feature_key: Input key for CXR tensor. Default is ``"cxr"``.
        cxr_mask_key: Optional input key for per-sample CXR availability
            mask. Default is ``"cxr_mask"``. To surface the mask through the
            standard dataloader, include ``"cxr_mask": "tensor"`` in your
            dataset's ``input_schema`` with per-sample availability values
            (1 = CXR present, 0 = missing).
        ehr_hidden_dim: Hidden dimension for EHR encoder. Default is 256.
        ehr_num_layers: Number of EHR LSTM layers. Default is 2.
        cxr_backbone: CXR ResNet backbone. Default is ``"resnet34"``.
        cxr_pretrained: Whether to use pretrained CXR backbone.
            Default is True.
        fusion_hidden_dim: Hidden dimension for fusion LSTM. Default is 512.
        projection_dim: CXR projection output dimension. Default is 256.
        dropout: Dropout rate. Default is 0.5.

    Raises:
        ValueError: If the dataset does not have exactly one label key, if
            ``ehr_feature_key`` is not a dataset feature, or if the EHR
            input dimension cannot be inferred.

    Examples:
        >>> import torch
        >>> from pyhealth.datasets import create_sample_dataset, get_dataloader
        >>> from pyhealth.models import MedFuse
        >>> samples = [
        ...     {
        ...         "patient_id": "p0",
        ...         "visit_id": "v0",
        ...         "ehr": torch.randn(5, 10).tolist(),
        ...         "cxr": torch.randn(3, 32, 32).tolist(),
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "p1",
        ...         "visit_id": "v1",
        ...         "ehr": torch.randn(5, 10).tolist(),
        ...         "cxr": torch.randn(3, 32, 32).tolist(),
        ...         "label": 0,
        ...     },
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={"ehr": "tensor", "cxr": "tensor"},
        ...     output_schema={"label": "binary"},
        ... )
        >>> loader = get_dataloader(dataset, batch_size=2)
        >>> model = MedFuse(dataset=dataset, cxr_pretrained=False)
        >>> batch = next(iter(loader))
        >>> output = model(**batch)
        >>> output["y_prob"].shape
        torch.Size([2, 1])
    """

    def __init__(
        self,
        dataset: SampleDataset,
        ehr_feature_key: str = "ehr",
        cxr_feature_key: str = "cxr",
        cxr_mask_key: Optional[str] = "cxr_mask",
        ehr_hidden_dim: int = 256,
        ehr_num_layers: int = 2,
        cxr_backbone: str = "resnet34",
        cxr_pretrained: bool = True,
        fusion_hidden_dim: int = 512,
        projection_dim: int = 256,
        dropout: float = 0.5,
    ) -> None:
        super().__init__(dataset=dataset)

        if len(self.label_keys) != 1:
            raise ValueError("MedFuse supports exactly one label key.")
        if ehr_feature_key not in self.feature_keys:
            raise ValueError(
                f"ehr_feature_key '{ehr_feature_key}' not found in dataset "
                f"features: {self.feature_keys}."
            )

        self.ehr_feature_key = ehr_feature_key
        self.cxr_feature_key = cxr_feature_key
        self.cxr_mask_key = cxr_mask_key
        self.label_key = self.label_keys[0]
        self.mode = self._resolve_mode(self.dataset.output_schema[self.label_key])
        # Records whether the dataset declares a CXR feature. forward() does
        # not consult this flag; it checks the batch for the key instead.
        self.has_cxr_feature = cxr_feature_key in self.feature_keys
        self.ehr_input_dim = self._infer_ehr_input_dim(ehr_feature_key)

        output_size = self.get_output_size()
        self.layer = MedFuseLayer(
            ehr_input_dim=self.ehr_input_dim,
            ehr_hidden_dim=ehr_hidden_dim,
            ehr_num_layers=ehr_num_layers,
            cxr_backbone=cxr_backbone,
            cxr_pretrained=cxr_pretrained,
            fusion_hidden_dim=fusion_hidden_dim,
            projection_dim=projection_dim,
            dropout=dropout,
            num_labels=output_size,
        )

    def _extract_feature_value(
        self,
        feature_key: str,
        feature: torch.Tensor | tuple[torch.Tensor, ...],
    ) -> torch.Tensor:
        """Extracts the semantic ``value`` tensor from a feature payload.

        Args:
            feature_key: Name of the feature in the dataset.
            feature: Either the tensor itself, or a tuple laid out according
                to the feature processor's ``schema()``.

        Returns:
            ``feature`` when it is already a tensor; otherwise the tuple
            element at the position of ``"value"`` in the schema.

        Raises:
            ValueError: If the feature has no registered processor, or its
                schema has no tensor-valued ``"value"`` entry.
        """
        if isinstance(feature, torch.Tensor):
            return feature

        if feature_key not in self.dataset.input_processors:
            raise ValueError(
                f"Feature '{feature_key}' is not defined in dataset processors."
            )

        schema = self.dataset.input_processors[feature_key].schema()
        if "value" in schema:
            value = feature[schema.index("value")]
            if isinstance(value, torch.Tensor):
                return value

        raise ValueError(
            f"Feature '{feature_key}' must provide a tensor value in its schema."
        )

    def _infer_ehr_input_dim(self, ehr_feature_key: str) -> int:
        """Infers EHR input feature dimension from processed dataset samples.

        Scans the dataset and returns the size of the last axis of the first
        EHR feature that has at least two dimensions.

        Args:
            ehr_feature_key: Name of the EHR feature in each sample.

        Returns:
            The per-timestep EHR feature dimension.

        Raises:
            ValueError: If no sample yields an EHR value with two or more
                dimensions.
        """
        for sample in self.dataset:
            if ehr_feature_key not in sample:
                continue

            feature = sample[ehr_feature_key]
            if isinstance(feature, tuple):
                value = self._extract_feature_value(ehr_feature_key, feature)
            elif isinstance(feature, torch.Tensor):
                value = feature
            else:
                value = torch.as_tensor(feature)

            if value.dim() >= 2:
                return int(value.shape[-1])

        raise ValueError(
            "Unable to infer EHR input dimension. Ensure EHR feature tensors "
            "have shape [seq_len, feature_dim]."
        )

    def forward(
        self,
        **kwargs: torch.Tensor | tuple[torch.Tensor, ...],
    ) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            **kwargs: Model inputs. Expected keys include:

                - ``ehr_feature_key`` (required): EHR tensor.
                - ``cxr_feature_key`` (optional): CXR tensor.
                - ``cxr_mask_key`` (optional): CXR availability mask.
                - ``label_key`` (optional): Label tensor for loss
                  computation.

        Returns:
            A dictionary containing:

            - ``logit``: Raw logits.
            - ``y_prob``: Predicted probabilities.
            - ``loss``: Loss tensor when labels are provided.
            - ``y_true``: Ground-truth labels when provided.

        Raises:
            ValueError: If the EHR feature is missing or the CXR mask is
                passed as a tuple.
        """
        if self.ehr_feature_key not in kwargs:
            raise ValueError(
                f"Missing required EHR feature key: '{self.ehr_feature_key}'."
            )

        # No typing.cast here: ``kwargs`` values already carry the expected
        # type, and a ``X | Y`` union expression is evaluated at runtime
        # (the __future__ import only defers annotations), which raises
        # TypeError on Python < 3.10.
        ehr_input = self._extract_feature_value(
            self.ehr_feature_key,
            kwargs[self.ehr_feature_key],
        )
        ehr_input = ehr_input.to(self.device).float()

        cxr_input: Optional[torch.Tensor] = None
        if self.cxr_feature_key in kwargs:
            cxr_input = self._extract_feature_value(
                self.cxr_feature_key,
                kwargs[self.cxr_feature_key],
            )
            cxr_input = cxr_input.to(self.device).float()

        cxr_mask: Optional[torch.Tensor] = None
        if self.cxr_mask_key is not None and self.cxr_mask_key in kwargs:
            raw_mask = kwargs[self.cxr_mask_key]
            if isinstance(raw_mask, tuple):
                raise ValueError("cxr_mask must be provided as a tensor.")
            if isinstance(raw_mask, torch.Tensor):
                cxr_mask = raw_mask.to(self.device)
            else:
                cxr_mask = torch.as_tensor(raw_mask, device=self.device)

        logits = self.layer(
            ehr_input=ehr_input,
            cxr_input=cxr_input,
            cxr_mask=cxr_mask,
        )
        y_prob = self.prepare_y_prob(logits)

        results: Dict[str, torch.Tensor] = {
            "logit": logits,
            "y_prob": y_prob,
        }

        # The loss is only computed when the batch carries labels.
        if self.label_key in kwargs:
            y_true = cast(torch.Tensor, kwargs[self.label_key]).to(self.device)
            results["loss"] = self.get_loss_function()(logits, y_true)
            results["y_true"] = y_true

        return results