import os
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import numpy as np
from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel
from pyhealth.models.embedding import EmbeddingModel
from pyhealth.processors.base_processor import FeatureProcessor
from pyhealth.processors import (
SequenceProcessor, StageNetProcessor, StageNetTensorProcessor, TimeseriesProcessor, TensorProcessor, MultiHotProcessor
)
from pyhealth.processors.base_processor import FeatureProcessor
from pyhealth.models.utils import get_last_visit
from pyhealth import BASE_CACHE_PATH as CACHE_PATH
from pyhealth.medcode import ATC
[docs]class MICRONLayer(nn.Module):
"""MICRON layer.
Paper: Chaoqi Yang et al. Change Matters: Medication Change Prediction
with Recurrent Residual Networks. IJCAI 2021.
This layer is used in the MICRON model. But it can also be used as a
standalone layer.
Args:
input_size: input feature size.
hidden_size: hidden feature size.
num_drugs: total number of drugs to recommend.
lam: regularization parameter for the reconstruction loss. Default is 0.1.
Examples:
>>> from pyhealth.models import MICRONLayer
>>> patient_emb = torch.randn(3, 5, 32) # [patient, visit, input_size]
>>> drugs = torch.randint(0, 2, (3, 50)).float()
>>> layer = MICRONLayer(32, 64, 50)
>>> loss, y_prob = layer(patient_emb, drugs)
>>> loss.shape
torch.Size([])
>>> y_prob.shape # Probabilities for each drug
torch.Size([3, 50])
"""
def __init__(
self, input_size: int, hidden_size: int, num_drugs: int, lam: float = 0.1
):
super(MICRONLayer, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_labels = num_drugs
self.lam = lam
self.health_net = nn.Linear(input_size, hidden_size)
self.prescription_net = nn.Linear(hidden_size, hidden_size)
self.fc = nn.Linear(hidden_size, num_drugs)
self.bce_loss_fn = nn.BCEWithLogitsLoss()
[docs] @staticmethod
def compute_reconstruction_loss(
logits: torch.tensor, logits_residual: torch.tensor, mask: torch.tensor
) -> torch.tensor:
"""Compute reconstruction loss between predicted and actual medication changes.
The reconstruction loss measures how well the model captures medication changes
between consecutive visits by comparing the predicted changes (through residual
connections) with actual changes in prescriptions.
Args:
logits (torch.tensor): Raw logits for medication predictions across all visits.
logits_residual (torch.tensor): Residual logits representing predicted changes.
mask (torch.tensor): Boolean mask indicating valid visits.
Returns:
torch.tensor: Mean squared reconstruction loss value.
"""
rec_loss = torch.mean(
torch.square(
torch.sigmoid(logits[:, 1:, :])
- torch.sigmoid(logits[:, :-1, :] + logits_residual)
)
* mask[:, 1:].unsqueeze(2)
)
return rec_loss
[docs] def forward(
self,
patient_emb: torch.tensor,
drugs: torch.tensor,
mask: Optional[torch.tensor] = None,
) -> Tuple[torch.tensor, torch.tensor]:
"""Forward propagation.
Args:
patient_emb: a tensor of shape [patient, visit, input_size].
drugs: a multihot tensor of shape [patient, num_labels].
mask: an optional tensor of shape [patient, visit] where
1 indicates valid visits and 0 indicates invalid visits.
Returns:
loss: a scalar tensor representing the loss.
y_prob: a tensor of shape [patient, num_labels] representing
the probability of each drug.
"""
if mask is None:
mask = torch.ones_like(patient_emb[:, :, 0])
# (patient, visit, hidden_size)
health_rep = self.health_net(patient_emb)
drug_rep = self.prescription_net(health_rep)
logits = self.fc(drug_rep)
logits_last_visit = get_last_visit(logits, mask)
bce_loss = self.bce_loss_fn(logits_last_visit, drugs)
# (batch, visit-1, input_size)
health_rep_last = health_rep[:, :-1, :]
# (batch, visit-1, input_size)
health_rep_cur = health_rep[:, 1:, :]
# (batch, visit-1, input_size)
health_rep_residual = health_rep_cur - health_rep_last
drug_rep_residual = self.prescription_net(health_rep_residual)
logits_residual = self.fc(drug_rep_residual)
rec_loss = self.compute_reconstruction_loss(logits, logits_residual, mask)
loss = bce_loss + self.lam * rec_loss
y_prob = torch.sigmoid(logits_last_visit)
return loss, y_prob
[docs]class MICRON(BaseModel):
"""MICRON model (PyHealth 2.0 compatible).
Paper: Chaoqi Yang et al. Change Matters: Medication Change Prediction
with Recurrent Residual Networks. IJCAI 2021.
This model is for medication prediction using PyHealth 2.0 SampleDataset and processors.
It expects input_schema to include 'conditions' and 'procedures' as sequence features,
and output_schema to include 'drugs' as a multilabel/multihot feature.
Args:
dataset (SampleDataset): Dataset object containing patient records and schema information.
embedding_dim (int, optional): Dimension for feature embeddings. Defaults to 128.
hidden_dim (int, optional): Dimension for hidden layers. Defaults to 128.
**kwargs: Additional parameters passed to the MICRON layer (e.g., lam for loss weighting).
Attributes:
embedding_model (EmbeddingModel): Handles embedding of input features.
feature_processors (dict): Maps feature keys to their respective processors.
micron (MICRONLayer): Core MICRON layer for medication prediction.
Note:
The model expects specific schema configurations:
- input_schema should include 'conditions' and 'procedures' as sequence features
- output_schema should include 'drugs' as a multilabel/multihot feature
Example:
>>> from pyhealth.datasets import create_sample_dataset
>>> samples = [
... {
... "patient_id": "patient-0",
... "visit_id": "visit-0",
... "conditions": ["E11.9", "I10"],
... "procedures": ["0DJD8ZZ"],
... "drugs": ["metformin", "lisinopril"]
... }
... ]
>>> dataset = create_sample_dataset(
... samples=samples,
... input_schema={"conditions": "sequence", "procedures": "sequence"},
... output_schema={"drugs": "multilabel"},
... dataset_name="test",
... )
>>> model = MICRON(dataset=dataset)
"""
def __init__(
self,
dataset: SampleDataset,
embedding_dim: int = 128,
hidden_dim: int = 128,
**kwargs
):
super().__init__(dataset=dataset)
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
assert len(self.label_keys) == 1, "Only one label key is supported."
self.label_key = self.label_keys[0]
self.mode = self.dataset.output_schema[self.label_key]
self.embedding_model = EmbeddingModel(dataset, embedding_dim)
self.feature_processors = {
feature_key: self.dataset.input_processors[feature_key]
for feature_key in self.feature_keys
}
# validate kwargs for MICRON layer
if "input_size" in kwargs:
raise ValueError("input_size is determined by embedding_dim and number of features")
if "hidden_size" in kwargs:
raise ValueError("hidden_size is determined by hidden_dim")
if "num_drugs" in kwargs:
raise ValueError("num_drugs is determined by the dataset")
# Get label processor
label_processor = self.dataset.output_processors[self.label_key]
# Get vocabulary size using the standard size() method
if not hasattr(label_processor, "size"):
raise ValueError(
"Label processor must implement size() method. "
"The processor type is: " + type(label_processor).__name__
)
num_drugs = label_processor.size()
if num_drugs == 0:
raise ValueError("Label processor returned 0 size")
self.micron = MICRONLayer(
input_size=embedding_dim * len(self.feature_keys),
hidden_size=hidden_dim,
num_drugs=num_drugs,
**kwargs
)
# save ddi adjacency matrix for later use
ddi_adj = self.generate_ddi_adj()
np.save(os.path.join(CACHE_PATH, "ddi_adj.npy"), ddi_adj)
@staticmethod
def _split_temporal(feature):
if isinstance(feature, tuple) and len(feature) == 2:
return feature
return None, feature
def _ensure_tensor(self, feature_key: str, value) -> torch.Tensor:
if isinstance(value, torch.Tensor):
return value
processor = self.feature_processors[feature_key]
if isinstance(processor, (SequenceProcessor, StageNetProcessor)):
return torch.tensor(value, dtype=torch.long)
return torch.tensor(value, dtype=torch.float)
def _pool_embedding(self, x: torch.Tensor) -> torch.Tensor:
if x.dim() == 4:
x = x.sum(dim=2)
if x.dim() == 2:
x = x.unsqueeze(1)
# Make sure temporal dimension (dim=1) matches the longest sequence
if x.size(1) == 1:
# Repeat to handle shorter sequences
x = x.repeat(1, 2, 1)
return x
def _create_mask(self, feature_key: str, value: torch.Tensor) -> torch.Tensor:
processor = self.feature_processors[feature_key]
if isinstance(processor, SequenceProcessor):
mask = value != 0
elif isinstance(processor, StageNetProcessor):
if value.dim() >= 3:
mask = torch.any(value != 0, dim=-1)
else:
mask = value != 0
elif isinstance(processor, (TimeseriesProcessor, StageNetTensorProcessor)):
if value.dim() >= 3:
mask = torch.any(torch.abs(value) > 0, dim=-1)
elif value.dim() == 2:
mask = torch.any(torch.abs(value) > 0, dim=-1, keepdim=True)
else:
mask = torch.ones(
value.size(0),
1,
dtype=torch.bool,
device=value.device,
)
elif isinstance(processor, (TensorProcessor, MultiHotProcessor)):
mask = torch.ones(
value.size(0),
1,
dtype=torch.bool,
device=value.device,
)
else:
if value.dim() >= 2:
mask = torch.any(value != 0, dim=-1)
else:
mask = torch.ones(
value.size(0),
1,
dtype=torch.bool,
device=value.device,
)
if mask.dim() == 1:
mask = mask.unsqueeze(1)
mask = mask.bool()
if mask.dim() == 2:
invalid_rows = ~mask.any(dim=1)
if invalid_rows.any():
mask[invalid_rows, 0] = True
return mask
[docs] def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
"""Forward propagation with PyHealth 2.0 inputs.
Args:
**kwargs: Keyword arguments that include every feature key defined in
the dataset schema plus the label key. Additional arguments:
- register_hook (bool): whether to register attention hooks
- embed (bool): whether to return embeddings in output
Returns:
Dict[str, torch.Tensor]: Prediction dictionary containing the loss,
probabilities, labels, and optionally embeddings.
"""
patient_emb = []
embedding_inputs: Dict[str, torch.Tensor] = {}
masks: Dict[str, torch.Tensor] = {}
for feature_key in self.feature_keys:
_, value = self._split_temporal(kwargs[feature_key])
value_tensor = self._ensure_tensor(feature_key, value).to(self.device)
embedding_inputs[feature_key] = value_tensor
masks[feature_key] = self._create_mask(feature_key, value_tensor).to(self.device)
embedded = self.embedding_model(embedding_inputs)
for feature_key in self.feature_keys:
x = embedded[feature_key]
mask = masks[feature_key]
x = self._pool_embedding(x)
patient_emb.append(x)
# Concatenate along last dim: [batch, seq_len, embedding_dim * n_features]
patient_emb = torch.cat(patient_emb, dim=2)
# Use visit-level mask from first feature (or combine as needed)
mask = masks[self.feature_keys[0]]
# Labels: expects multihot [batch, num_labels]
y_true = kwargs[self.label_key].to(self.device)
loss, y_prob = self.micron(patient_emb, y_true, mask)
results = {"loss": loss, "y_prob": y_prob, "y_true": y_true}
if kwargs.get("embed", False):
results["embed"] = patient_emb
return results
[docs] def generate_ddi_adj(self) -> torch.Tensor:
"""Generates the drug-drug interaction (DDI) graph adjacency matrix using PyHealth 2.0 label processor."""
atc = ATC()
ddi = atc.get_ddi(gamenet_ddi=True)
label_processor = self.dataset.output_processors[self.label_key]
label_size = label_processor.size()
vocab_to_index = label_processor.label_vocab
ddi_adj = np.zeros((label_size, label_size))
ddi_atc3 = [
[ATC.convert(l[0], level=3), ATC.convert(l[1], level=3)] for l in ddi
]
for atc_i, atc_j in ddi_atc3:
if atc_i in vocab_to_index and atc_j in vocab_to_index:
ddi_adj[vocab_to_index[atc_i], vocab_to_index[atc_j]] = 1
ddi_adj[vocab_to_index[atc_j], vocab_to_index[atc_i]] = 1
return torch.tensor(ddi_adj, dtype=torch.float)