# Source code for pyhealth.models.medlink.model

from __future__ import annotations

from typing import Dict, List, Any, Optional, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from torch.nn.utils.rnn import pad_sequence

from ...datasets import SampleDataset
from ..base_model import BaseModel
from ..transformer import TransformerLayer
from ...processors import SequenceProcessor

from ..embedding import init_embedding_with_pretrained


def _build_shared_vocab(
    q_processor: SequenceProcessor,
    d_processor: SequenceProcessor,
    pad_token: str = "<pad>",
    unk_token: str = "<unk>",
) -> Dict[str, int]:
    """Build a shared token->index mapping from two fitted SequenceProcessors.

    The returned vocabulary is deterministic (sorted token order) and always
    includes `pad_token` and `unk_token`.
    """

    vocab: Dict[str, int] = {pad_token: 0, unk_token: 1}

    tokens = set(str(t) for t in q_processor.code_vocab.keys()) | set(
        str(t) for t in d_processor.code_vocab.keys()
    )
    tokens.discard(pad_token)
    tokens.discard(unk_token)

    for t in sorted(tokens):
        if t not in vocab:
            vocab[t] = len(vocab)
    return vocab


def _build_index_remap(
    processor: SequenceProcessor,
    shared_vocab: Dict[str, int],
    unk_idx: int,
) -> torch.Tensor:
    """Build a dense remap tensor old_idx -> shared_idx."""

    size = len(processor.code_vocab)
    remap = torch.full((size,), unk_idx, dtype=torch.long)
    for tok, old_idx in processor.code_vocab.items():
        tok_s = str(tok)
        remap[old_idx] = shared_vocab.get(tok_s, unk_idx)
    return remap


def _to_index_tensor(
    seq: Any,
    processor: SequenceProcessor,
) -> torch.Tensor:
    """Converts a single sequence to an index tensor using the fitted processor."""
    if isinstance(seq, torch.Tensor):
        return seq.long()
    if isinstance(seq, (list, tuple)):
        return processor.process(seq)
    # single token
    return processor.process([seq])


def _pad_and_remap(
    sequences: Sequence[Any],
    processor: SequenceProcessor,
    remap: torch.Tensor,
    pad_value: int = 0,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Batch variable-length sequences and translate them to shared indices.

    Each sequence is converted with ``_to_index_tensor``, the results are
    right-padded with ``pad_value`` to a common length, and the padded batch
    is looked up in ``remap``.

    Args:
        sequences: Batch of sequences (tensors, lists/tuples, or single tokens).
        processor: Fitted processor used to index non-tensor sequences.
        remap: Lookup tensor mapping processor indices to shared indices.
        pad_value: Processor-space index used to fill padded positions.
        device: If given, both the padded batch and ``remap`` are moved there
            before the lookup.

    Returns:
        ids_shared: LongTensor [B, L]
        mask: BoolTensor [B, L], True wherever the shared index is non-zero.
    """

    index_tensors = [_to_index_tensor(item, processor) for item in sequences]
    batch = pad_sequence(
        index_tensors, batch_first=True, padding_value=pad_value
    )
    if device is not None:
        batch, remap = batch.to(device), remap.to(device)

    shared_ids = remap[batch]
    # Shared index 0 is treated as padding; everything else is a valid token.
    valid = shared_ids.ne(0)
    return shared_ids, valid


class AdmissionPrediction(nn.Module):
    """Admission prediction module used by MedLink.

    This is a lightly-adapted version of the original MedLink implementation,
    refactored to work with PyHealth 2.0 processors (i.e., indexed tensors).

    The module embeds a sequence of code indices, encodes it with a
    ``TransformerLayer``, and scores every vocabulary entry by a dot product
    with the (shared) embedding table. The training target is the multi-hot
    set of codes that occur in the input itself.

    Args:
        code_vocab: Token -> index mapping; its size defines the vocabulary.
        embedding_dim: Size of the code embeddings and encoder features.
        heads: Forwarded to ``TransformerLayer``.
        dropout: Forwarded to ``TransformerLayer``.
        num_layers: Forwarded to ``TransformerLayer``.
        pretrained_emb_path: If truthy, the embedding table is initialised
            through ``init_embedding_with_pretrained`` from this path.
        freeze_pretrained: Passed as ``freeze`` to
            ``init_embedding_with_pretrained``.
    """

    def __init__(
        self,
        code_vocab: Dict[str, int],
        embedding_dim: int,
        heads: int = 2,
        dropout: float = 0.5,
        num_layers: int = 1,
        pretrained_emb_path: Optional[str] = None,
        freeze_pretrained: bool = False,
    ):
        super().__init__()
        self.code_vocab = code_vocab
        self.vocab_size = len(code_vocab)
        # Falls back to index 0 when the vocabulary has no explicit "<pad>".
        self.pad_idx = code_vocab.get("<pad>", 0)
        # None when the vocabulary has no "<cls>" token.
        self.cls_idx = code_vocab.get("<cls>")

        self.embedding = nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=self.pad_idx,
        )
        if pretrained_emb_path:
            init_embedding_with_pretrained(
                self.embedding,
                code_vocab,
                pretrained_emb_path,
                embedding_dim=embedding_dim,
                freeze=freeze_pretrained,
            )

        self.encoder = TransformerLayer(
            feature_size=embedding_dim,
            heads=heads,
            dropout=dropout,
            num_layers=num_layers,
        )
        self.criterion = nn.BCEWithLogitsLoss()

    def _multi_hot(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Build a multi-hot label vector per sample.

        Args:
            input_ids: LongTensor [B, L] of vocabulary indices.

        Returns:
            FloatTensor [B, V] with 1.0 at every index that occurs in the
            corresponding row of ``input_ids``, except for the pad and
            ``<cls>`` indices, which are always 0.
        """

        out = torch.zeros(
            input_ids.size(0),
            self.vocab_size,
            device=input_ids.device,
            dtype=torch.float,
        )
        # Writing a constant 1.0 makes repeated codes collapse to a single hit.
        out.scatter_(1, input_ids, 1.0)
        # Remove special tokens from labels.
        if self.pad_idx is not None:
            out[:, self.pad_idx] = 0
        if self.cls_idx is not None:
            out[:, self.cls_idx] = 0
        return out

    def get_loss(
        self, logits: torch.Tensor, target_multi_hot: torch.Tensor
    ) -> torch.Tensor:
        """Return the BCE-with-logits loss between logits and multi-hot targets.

        If the two tensors disagree on the batch dimension, both are truncated
        to the smaller batch size before the loss is computed.
        """
        true_batch_size = min(logits.shape[0], target_multi_hot.shape[0])
        return self.criterion(
            logits[:true_batch_size], target_multi_hot[:true_batch_size]
        )

    def forward(
        self, input_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute vocabulary logits and target multi-hot labels.

        Args:
            input_ids: LongTensor [B, L] in shared vocabulary indices.

        Returns:
            logits: FloatTensor [B, V]
            target: FloatTensor [B, V] multi-hot labels.
        """

        mask = input_ids != self.pad_idx
        x = self.embedding(input_ids)
        x, _ = self.encoder(x, mask=mask)

        # Use embedding table as vocabulary embedding.
        vocab_emb = self.embedding.weight  # [V, D]
        logits = torch.matmul(x, vocab_emb.T)  # [B, L, V]

        # Exclude padded positions from the max below. The fill value is the
        # most negative finite number of the logits' dtype: a literal such as
        # -1e9 cannot be represented in float16 and makes masked_fill raise
        # when the model runs in half precision / under autocast.
        fill_value = torch.finfo(logits.dtype).min
        logits = logits.masked_fill(~mask.unsqueeze(-1), fill_value)
        logits = logits.max(dim=1).values  # [B, V]

        target = self._multi_hot(input_ids)
        return logits, target





def _smoke_test() -> None:
    """Run one MedLink forward pass on MIMIC-III and print the loss.

    Minimal smoke-test matching the public example script: builds the
    patient-linkage samples, converts them to IR format, takes the first
    training batch and feeds it to the model.
    """
    from pyhealth.datasets import MIMIC3Dataset
    from pyhealth.models.medlink import (
        convert_to_ir_format,
        get_train_dataloader,
        tvt_split,
    )
    from pyhealth.tasks import PatientLinkageMIMIC3Task

    mimic3 = MIMIC3Dataset(
        root="/srv/local/data/physionet.org/files/mimiciii/1.4",
        # Extra tables were added for the task class.
        tables=["DIAGNOSES_ICD", "ADMISSIONS", "PATIENTS"],
        code_mapping={"ICD9CM": ("CCSCM", {})},
        dev=False,
        refresh_cache=False,
    )
    linkage_samples = mimic3.set_task(PatientLinkageMIMIC3Task())

    # Only the first three outputs are needed here; the rest are discarded.
    corpus, queries, qrels, *_ = convert_to_ir_format(linkage_samples.samples)
    # Keep just the training queries and their relevance judgements.
    tr_queries, _, _, tr_qrels, _, _ = tvt_split(queries, qrels)

    loader = get_train_dataloader(corpus, tr_queries, tr_qrels, batch_size=4)
    first_batch = next(iter(loader))

    # NOTE(review): MedLink is not defined in the visible part of this module;
    # confirm it is in scope at module level.
    linker = MedLink(
        dataset=linkage_samples,
        feature_keys=["conditions"],
        embedding_dim=32,
    )
    result = linker(**first_batch)
    print("loss:", result["loss"].item())


if __name__ == "__main__":
    _smoke_test()