# Source code for pyhealth.models.utils

from typing import List

import torch


def batch_to_multihot(label: List[List[int]], num_labels: int) -> torch.Tensor:
    """Converts a batch of label-index lists to multihot format.

    Args:
        label: [batch size, *]; for each sample, the indices of the labels
            that are set.
        num_labels: total number of labels

    Returns:
        multihot: [batch size, num_labels] tensor created by ``torch.zeros``
            (default dtype), holding 1 at every listed index and 0 elsewhere.
    """
    multihot = torch.zeros((len(label), num_labels))
    for row, active_labels in enumerate(label):
        # Indexing with the whole list sets all of this row's columns in one
        # assignment; repeated indices simply write 1 again.
        multihot[row, active_labels] = 1
    return multihot


def get_last_visit(hidden_states, mask):
    """Gets the last visit from the sequence model.

    The last visit of each sample is taken to be at index
    ``sum(mask) - 1``, i.e. the valid positions are treated as a contiguous
    prefix of the sequence.

    Args:
        hidden_states: [batch size, seq len, hidden_size]
        mask: [batch size, seq len], or None to take the final timestep of
            every sample.

    Returns:
        last_visit: [batch size, hidden_size]
    """
    if mask is None:
        return hidden_states[:, -1, :]

    last_index = torch.sum(mask.long(), 1) - 1
    # Clamp to 0 so that samples with an all-zero mask (no valid visits)
    # fall back to the first timestep instead of producing a negative
    # index that would crash torch.gather.
    last_index = last_index.clamp(min=0)

    # One index per (sample, hidden unit) is enough: shape
    # [batch size, 1, hidden_size] rather than a full
    # [batch size, seq len, hidden_size] index tensor.
    gather_index = last_index.view(-1, 1, 1).expand(-1, 1, hidden_states.shape[2])
    last_hidden_states = torch.gather(hidden_states, 1, gather_index)
    return last_hidden_states[:, 0, :]


def get_rightmost_masked_timestep(hidden_states, mask):
    """Gather hidden state at the last True position in ``mask`` per row.

    Unlike :func:`get_last_visit`, this does **not** assume valid tokens form
    a contiguous prefix; it picks the maximum index where ``mask`` is True.
    Use for MPF / CEHR layouts where padding can appear between boundary
    tokens.

    Args:
        hidden_states: ``[batch, seq_len, hidden_size]``.
        mask: ``[batch, seq_len]`` bool (other dtypes are converted with
            ``.bool()``), or None to take the final timestep of every row.

    Returns:
        Tensor ``[batch, hidden_size]``. Rows whose mask has no True entry
        fall back to timestep 0.
    """
    if mask is None:
        return hidden_states[:, -1, :]

    batch, seq_len, hidden = hidden_states.shape
    # torch.where needs a boolean condition; converting here lets 0/1
    # integer or float masks be used as well.
    mask = mask.bool()

    positions = torch.arange(
        seq_len, device=hidden_states.device, dtype=torch.long
    )
    positions = positions.unsqueeze(0).expand(batch, -1)

    # Masked-out positions become -1 so the row-wise max is the rightmost
    # True position; an all-False row yields -1, which the clamp maps to 0.
    masked_positions = torch.where(mask, positions, torch.full_like(positions, -1))
    last_idx = masked_positions.max(dim=1).values.clamp(min=0)

    gather_idx = last_idx.view(batch, 1, 1).expand(batch, 1, hidden)
    gathered = torch.gather(hidden_states, 1, gather_idx)
    return gathered[:, 0, :]