# Source code for pyhealth.models.safedrug

from collections import defaultdict
from copy import deepcopy
from typing import List, Tuple, Dict, Optional
import os

import numpy as np
import rdkit.Chem.BRICS as BRICS
import torch
import torch.nn as nn
from rdkit import Chem

from pyhealth.datasets import SampleEHRDataset
from pyhealth.medcode import ATC
from pyhealth.metrics import ddi_rate_score
from pyhealth.models import BaseModel
from pyhealth.models.utils import get_last_visit
from pyhealth import BASE_CACHE_PATH as CACHE_PATH

class MaskLinear(nn.Module):
    """MaskLinear layer.

    A linear layer whose weight matrix is multiplied element-wise by a
    caller-supplied hard mask on every forward pass. It is used in the
    SafeDrug model.

    Note that, unlike ``nn.Linear``, the weight is stored with shape
    ``[in_features, out_features]``.

    Args:
        in_features: input feature size.
        out_features: output feature size.
        bias: whether to use bias. Default is True.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super(MaskLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # torch.empty only allocates; reset_parameters() fills the values.
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initializes weight (and bias) uniformly in ``[-stdv, stdv]``.

        ``stdv`` is ``1 / sqrt(out_features)`` because the weight is stored as
        ``[in_features, out_features]`` and its second dimension is used.
        """
        stdv = 1.0 / self.weight.size(1) ** 0.5
        nn.init.uniform_(self.weight, -stdv, stdv)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -stdv, stdv)

    def forward(self, input: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Applies the masked linear transformation.

        Args:
            input: input feature tensor of shape [batch size, ..., input_size].
            mask: mask tensor of shape [input_size, output_size], i.e., the same
                size as the weight matrix.

        Returns:
            Output tensor of shape [batch size, ..., output_size].
        """
        masked_weight = torch.mul(self.weight, mask)
        # matmul (rather than mm) broadcasts over any leading dimensions, so
        # inputs with more than two dimensions work as documented above.
        output = torch.matmul(input, masked_weight)
        if self.bias is None:
            return output
        return output + self.bias


class MolecularGraphNeuralNetwork(nn.Module):
    """Molecular Graph Neural Network.

    Paper: Masashi Tsubaki et al. Compound-protein interaction
    prediction with end-to-end learning of neural networks for
    graphs and sequences. Bioinformatics, 2019.

    Args:
        num_fingerprints: total number of fingerprints.
        dim: embedding dimension of the fingerprint vectors.
        layer_hidden: number of hidden layers.
    """

    def __init__(self, num_fingerprints, dim, layer_hidden):
        super().__init__()
        self.layer_hidden = layer_hidden
        self.embed_fingerprint = nn.Embedding(num_fingerprints, dim)
        per_layer_transforms = []
        for _ in range(layer_hidden):
            per_layer_transforms.append(nn.Linear(dim, dim))
        self.W_fingerprint = nn.ModuleList(per_layer_transforms)

    def update(self, matrix, vectors, layer):
        """Runs one message-passing step with the linear map of ``layer``.

        Each row is transformed and passed through ReLU; the result is then
        added to its ``matrix``-weighted combination of all transformed rows.
        """
        transform = self.W_fingerprint[layer]
        hidden_vectors = torch.relu(transform(vectors))
        aggregated = torch.mm(matrix, hidden_vectors)
        return hidden_vectors + aggregated

    def sum(self, vectors, axis):
        """Splits ``vectors`` along dim 0 by ``axis`` and sums every chunk."""
        chunks = torch.split(vectors, axis)
        return torch.stack([torch.sum(chunk, 0) for chunk in chunks])

    def mean(self, vectors, axis):
        """Splits ``vectors`` along dim 0 by ``axis`` and averages every chunk."""
        chunks = torch.split(vectors, axis)
        return torch.stack([torch.mean(chunk, 0) for chunk in chunks])

    def forward(self, fingerprints, adjacencies, molecular_sizes):
        """Computes one vector per molecule.

        Args:
            fingerprints: a list of fingerprints
            adjacencies: a list of adjacency matrices
            molecular_sizes: a list of the number of atoms in each molecule
        """
        # MPNN layers: repeatedly update the fingerprint vectors.
        fingerprint_vectors = self.embed_fingerprint(fingerprints)
        layer = 0
        while layer < self.layer_hidden:
            fingerprint_vectors = self.update(
                adjacencies, fingerprint_vectors, layer
            )
            layer += 1

        # Molecular vector: sum of the fingerprint vectors of each molecule.
        return self.sum(fingerprint_vectors, molecular_sizes)


class SafeDrugLayer(nn.Module):
    """SafeDrug model.

    Paper: Chaoqi Yang et al. SafeDrug: Dual Molecular Graph Encoders for
    Recommending Effective and Safe Drug Combinations. IJCAI 2021.

    This layer is used in the SafeDrug model. But it can also be used as a
    standalone layer. Note that we improve the layer a little bit to make it
    compatible with the package. Original code can be found at
    https://github.com/ycq091044/SafeDrug/blob/main/src/models.py.

    Args:
        hidden_size: hidden feature size.
        mask_H: the mask matrix H of shape [num_drugs, num_substructures].
        ddi_adj: an adjacency tensor of shape [num_drugs, num_drugs].
        num_fingerprints: total number of different fingerprints.
        molecule_set: a list of molecule tuples (A, B, C) of length num_molecules.
            - A <torch.tensor>: fingerprints of atoms in the molecule
            - B <torch.tensor>: adjacency matrix of the molecule
            - C <int>: molecular_size
        average_projection: a tensor of shape [num_drugs, num_molecules]
            representing the average projection for aggregating multiple
            molecules of the same drug into one vector.
        kp: correcting factor for the proportional signal. Default is 0.05.
        target_ddi: DDI acceptance rate. Default is 0.08.

    Raises:
        ValueError: if ``molecule_set`` is empty.
    """

    def __init__(
        self,
        hidden_size: int,
        mask_H: torch.Tensor,
        ddi_adj: torch.Tensor,
        num_fingerprints: int,
        molecule_set: List[Tuple],
        average_projection: torch.Tensor,
        kp: float = 0.05,
        target_ddi: float = 0.08,
    ):
        super(SafeDrugLayer, self).__init__()
        if len(molecule_set) == 0:
            # Without this check the unzip below fails with an opaque
            # IndexError.
            raise ValueError("molecule_set must contain at least one molecule")

        self.hidden_size = hidden_size
        self.kp = kp
        self.target_ddi = target_ddi

        # Fixed (non-trainable) tensors are registered as parameters so that
        # they follow the module across devices and are saved with it.
        self.mask_H = nn.Parameter(mask_H, requires_grad=False)
        self.ddi_adj = nn.Parameter(ddi_adj, requires_grad=False)

        # medication space size
        label_size = mask_H.shape[0]

        # local bipartite encoder
        self.bipartite_transform = nn.Linear(hidden_size, mask_H.shape[1])
        # self.bipartite_output = MaskLinear(mask_H.shape[1], label_size, False)
        self.bipartite_output = nn.Linear(mask_H.shape[1], label_size)

        # global MPNN encoder (add fingerprints and adjacency matrix to
        # parameter list)
        # unzip the (fingerprints, adjacency, size) tuples into three sequences
        mpnn_molecule_set = list(zip(*molecule_set))

        # process three parts of information
        fingerprints = torch.cat(mpnn_molecule_set[0])
        self.fingerprints = nn.Parameter(fingerprints, requires_grad=False)
        adjacencies = self.pad(mpnn_molecule_set[1], 0)
        self.adjacencies = nn.Parameter(adjacencies, requires_grad=False)
        self.molecule_sizes = mpnn_molecule_set[2]
        self.average_projection = nn.Parameter(average_projection, requires_grad=False)

        self.mpnn = MolecularGraphNeuralNetwork(
            num_fingerprints, hidden_size, layer_hidden=2
        )
        self.mpnn_output = nn.Linear(label_size, label_size)
        self.mpnn_layernorm = nn.LayerNorm(label_size)

        # NOTE(review): this layer is never used in forward(); it is kept so
        # that the module's parameter set (and saved checkpoints) is unchanged.
        self.test = nn.Linear(hidden_size, label_size)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def pad(self, matrices, pad_value):
        """Pads the list of matrices.

        Padding with a pad_value (e.g., 0) for batch processing. For example,
        given a list of matrices [A, B, C], we obtain a new matrix
        [A00, 0B0, 00C], where 0 is the zero (i.e., pad value) matrix.

        Args:
            matrices: a sequence of 2-D tensors.
            pad_value: the scalar used for the off-diagonal blocks.

        Returns:
            A float32 tensor whose diagonal blocks are the input matrices.
        """
        shapes = [m.shape for m in matrices]
        total_rows = sum(s[0] for s in shapes)
        total_cols = sum(s[1] for s in shapes)
        pad_matrices = torch.full(
            (total_rows, total_cols), float(pad_value), dtype=torch.float32
        )
        row, col = 0, 0
        for matrix, (m, n) in zip(matrices, shapes):
            pad_matrices[row : row + m, col : col + n] = matrix
            row += m
            col += n
        return pad_matrices

    def calculate_loss(
        self, logits: torch.Tensor, y_prob: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """Combines the BCE loss with a DDI penalty.

        The penalty is only mixed in when the DDI rate of the thresholded
        predictions exceeds ``target_ddi``; its weight grows with the excess,
        scaled by ``kp``.

        Args:
            logits: raw scores, one row per patient and one column per drug.
            y_prob: ``sigmoid(logits)``.
            labels: multihot ground truth with the same shape as ``logits``.

        Returns:
            A scalar loss tensor.
        """
        mul_pred_prob = y_prob.T @ y_prob  # (voc_size, voc_size)
        batch_ddi_loss = (
            torch.sum(mul_pred_prob.mul(self.ddi_adj)) / self.ddi_adj.shape[0] ** 2
        )

        # Indices of the drugs predicted for every sample (probability >= 0.5).
        probabilities = y_prob.detach().cpu().numpy()
        y_pred = [np.where(sample >= 0.5)[0] for sample in probabilities]

        cur_ddi_rate = ddi_rate_score(y_pred, self.ddi_adj.cpu().numpy())
        if cur_ddi_rate > self.target_ddi:
            beta = max(0.0, 1 + (self.target_ddi - cur_ddi_rate) / self.kp)
            add_loss = batch_ddi_loss
        else:
            add_loss, beta = 0, 1

        # obtain target, loss, prob, pred
        bce_loss = self.loss_fn(logits, labels)
        loss = beta * bce_loss + (1 - beta) * add_loss
        return loss

    def forward(
        self,
        patient_emb: torch.Tensor,
        drugs: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward propagation.

        Args:
            patient_emb: a tensor of shape [patient, visit, input_size].
            drugs: a multihot tensor of shape [patient, num_labels].
            mask: an optional tensor of shape [patient, visit] where 1
                indicates valid visits and 0 indicates invalid visits.

        Returns:
            loss: a scalar tensor representing the loss.
            y_prob: a tensor of shape [patient, num_labels] representing
                the probability of each drug.
        """
        if mask is None:
            # treat every visit as valid
            mask = torch.ones_like(patient_emb[:, :, 0])

        query = get_last_visit(patient_emb, mask)  # (batch, dim)

        # MPNN Encoder
        MPNN_emb = self.mpnn(
            self.fingerprints, self.adjacencies, self.molecule_sizes
        )  # (#molecule, hidden_size)
        MPNN_emb = torch.mm(self.average_projection, MPNN_emb)  # (#med, hidden_size)
        MPNN_match = torch.sigmoid(torch.mm(query, MPNN_emb.T))  # (patient, #med)
        MPNN_att = self.mpnn_layernorm(
            MPNN_match + self.mpnn_output(MPNN_match)
        )  # (batch, #med)

        # Bipartite Encoder
        bipartite_emb = torch.sigmoid(self.bipartite_transform(query))  # (batch, dim)
        bipartite_att = self.bipartite_output(bipartite_emb)  # (batch, #med)

        # combine the two encoders
        logits = bipartite_att * MPNN_att

        # calculate the ddi_loss by PID strategy and add to final loss
        y_prob = torch.sigmoid(logits)
        loss = self.calculate_loss(logits, y_prob, drugs)

        return loss, y_prob
class SafeDrug(BaseModel):
    """SafeDrug model.

    Paper: Chaoqi Yang et al. SafeDrug: Dual Molecular Graph Encoders for
    Recommending Effective and Safe Drug Combinations. IJCAI 2021.

    Note:
        This model is only for medication prediction which takes conditions
        and procedures as feature_keys, and drugs as label_key. It only
        operates on the visit level.

    Note:
        This model only accepts ATC level 3 as medication codes.

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        embedding_dim: the embedding dimension. Default is 128.
        hidden_dim: the hidden dimension. Default is 128.
        num_layers: the number of layers used in RNN. Default is 1.
        dropout: the dropout rate. Default is 0.5.
        **kwargs: other parameters for the SafeDrug layer.

    Raises:
        ValueError: if ``kwargs`` tries to override an argument of the
            SafeDrug layer that is derived from the dataset or ``hidden_dim``.
    """

    def __init__(
        self,
        dataset: SampleEHRDataset,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        num_layers: int = 1,
        dropout: float = 0.5,
        **kwargs,
    ):
        super(SafeDrug, self).__init__(
            dataset=dataset,
            feature_keys=["conditions", "procedures"],
            label_key="drugs",
            mode="multilabel",
        )
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout

        self.feat_tokenizers = self.get_feature_tokenizers()
        self.label_tokenizer = self.get_label_tokenizer()
        self.embeddings = self.get_embedding_layers(self.feat_tokenizers, embedding_dim)

        # drug space size
        self.label_size = self.label_tokenizer.get_vocabulary_size()

        self.all_smiles_list = self.generate_smiles_list()
        mask_H = self.generate_mask_H()
        (
            molecule_set,
            num_fingerprints,
            average_projection,
        ) = self.generate_molecule_info()
        ddi_adj = self.generate_ddi_adj()

        self.cond_rnn = nn.GRU(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True,
        )
        self.proc_rnn = nn.GRU(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True,
        )
        self.query = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, hidden_dim),
        )

        # validate kwargs for the SafeDrug layer: these arguments are computed
        # here and must not be overridden by the caller
        reserved_kwargs = (
            ("hidden_size", "hidden_size is determined by hidden_dim"),
            ("mask_H", "mask_H is determined by the dataset"),
            ("ddi_adj", "ddi_adj is determined by the dataset"),
            ("num_fingerprints", "num_fingerprints is determined by the dataset"),
            ("molecule_set", "molecule_set is determined by the dataset"),
            ("average_projection", "average_projection is determined by the dataset"),
        )
        for key, message in reserved_kwargs:
            if key in kwargs:
                raise ValueError(message)

        self.safedrug = SafeDrugLayer(
            hidden_size=hidden_dim,
            mask_H=mask_H,
            ddi_adj=ddi_adj,
            num_fingerprints=num_fingerprints,
            molecule_set=molecule_set,
            average_projection=average_projection,
            **kwargs,
        )

        # save ddi adj (reuse the matrix built above instead of rebuilding it)
        np.save(os.path.join(CACHE_PATH, "ddi_adj.npy"), ddi_adj.numpy())

    def generate_ddi_adj(self) -> torch.Tensor:
        """Generates the DDI graph adjacency matrix.

        Returns:
            A symmetric 0/1 float tensor of shape [label_size, label_size].
            An entry is 1 when the two ATC level 3 drugs form a DDI pair and
            both are present in the label vocabulary.
        """
        atc = ATC()
        ddi = atc.get_ddi(gamenet_ddi=True)
        label_size = self.label_tokenizer.get_vocabulary_size()
        vocab_to_index = self.label_tokenizer.vocabulary
        ddi_adj = np.zeros((label_size, label_size))
        ddi_atc3 = [
            [ATC.convert(pair[0], level=3), ATC.convert(pair[1], level=3)]
            for pair in ddi
        ]
        for atc_i, atc_j in ddi_atc3:
            if atc_i in vocab_to_index and atc_j in vocab_to_index:
                index_i = vocab_to_index(atc_i)
                index_j = vocab_to_index(atc_j)
                ddi_adj[index_i, index_j] = 1
                ddi_adj[index_j, index_i] = 1
        ddi_adj = torch.FloatTensor(ddi_adj)
        return ddi_adj

    def generate_smiles_list(self) -> List[List[str]]:
        """Generates the list of SMILES strings.

        Returns:
            A list of length ``label_size``; entry ``i`` holds the SMILES
            strings (at most one) kept for the drug with label index ``i``.
        """
        atc3_to_smiles = {}
        atc = ATC()
        for code in atc.graph.nodes:
            # only 7-character codes are considered
            if len(code) != 7:
                continue
            code_atc3 = ATC.convert(code, level=3)
            smiles = atc.graph.nodes[code]["smiles"]
            # a value that is not equal to itself is NaN, i.e. a missing SMILES
            if smiles != smiles:
                continue
            atc3_to_smiles.setdefault(code_atc3, []).append(smiles)
        # just take first one for computational efficiency
        atc3_to_smiles = {k: v[:1] for k, v in atc3_to_smiles.items()}
        all_smiles_list = [[] for _ in range(self.label_size)]
        vocab_to_index = self.label_tokenizer.vocabulary
        for atc3, smiles_list in atc3_to_smiles.items():
            if atc3 in vocab_to_index:
                index = vocab_to_index(atc3)
                all_smiles_list[index] += smiles_list
        return all_smiles_list

    def generate_mask_H(self) -> torch.Tensor:
        """Generates the molecular segmentation mask H.

        Returns:
            A 0/1 float tensor of shape [label_size, num_substructures] where
            entry ``(i, j)`` is 1 if the BRICS decomposition of a molecule of
            drug ``i`` contains substructure ``j``. Columns follow the sorted
            order of the substructure strings.
        """
        all_substructures_list = [[] for _ in range(self.label_size)]
        for index, smiles_list in enumerate(self.all_smiles_list):
            for smiles in smiles_list:
                mol = Chem.MolFromSmiles(smiles)
                # skip SMILES strings that cannot be parsed
                if mol is None:
                    continue
                substructures = BRICS.BRICSDecompose(mol)
                all_substructures_list[index] += substructures

        # all segment set; sorted so the column order is reproducible across
        # runs, and indexed by a dict to avoid a linear search per entry
        substructures_set = sorted(
            {s for substructures in all_substructures_list for s in substructures}
        )
        column_of = {s: column for column, s in enumerate(substructures_set)}

        # mask_H
        mask_H = np.zeros((self.label_size, len(substructures_set)))
        for index, substructures in enumerate(all_substructures_list):
            for s in substructures:
                mask_H[index, column_of[s]] = 1
        mask_H = torch.FloatTensor(mask_H)
        return mask_H

    def generate_molecule_info(self, radius: int = 1):
        """Generates the molecule information.

        Args:
            radius: number of fingerprint refinement iterations. Default is 1.

        Returns:
            molecule_set: a list of (fingerprints, adjacency, molecular_size)
                tuples, one per successfully parsed molecule.
            num_fingerprints: the number of distinct fingerprint ids assigned.
            average_projection: a float tensor of shape
                [label_size, num_molecules]; row ``i`` averages the molecules
                that belong to drug ``i``.
        """

        def create_atoms(mol, atom2idx):
            """Transform the atom types in a molecule (e.g., H, C, and O)
            into the indices (e.g., H=0, C=1, and O=2).
            Note that each atom index considers the aromaticity.
            """
            atoms = [a.GetSymbol() for a in mol.GetAtoms()]
            for a in mol.GetAromaticAtoms():
                i = a.GetIdx()
                atoms[i] = (atoms[i], "aromatic")
            atoms = [atom2idx[a] for a in atoms]
            return np.array(atoms)

        def create_ijbonddict(mol, bond2idx):
            """Create a dictionary, in which each key is a node ID
            and each value is the tuples of its neighboring node
            and chemical bond (e.g., single and double) IDs.
            """
            i_jbond_dict = defaultdict(list)
            for b in mol.GetBonds():
                i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
                bond = bond2idx[str(b.GetBondType())]
                i_jbond_dict[i].append((j, bond))
                i_jbond_dict[j].append((i, bond))
            return i_jbond_dict

        def extract_fingerprints(r, atoms, i_jbond_dict, fingerprint2idx, edge2idx):
            """Extract the fingerprints from a molecular graph
            based on Weisfeiler-Lehman algorithm.
            """
            nodes = [fingerprint2idx[a] for a in atoms]
            i_jedge_dict = i_jbond_dict

            for _ in range(r):
                # Update each node ID considering its neighboring nodes and
                # edges. The updated node IDs are the fingerprint IDs.
                nodes_ = list(nodes)
                for i, j_edge in i_jedge_dict.items():
                    neighbors = [(nodes[j], edge) for j, edge in j_edge]
                    fingerprint = (nodes[i], tuple(sorted(neighbors)))
                    nodes_[i] = fingerprint2idx[fingerprint]

                # Also update each edge ID considering its two nodes on both
                # sides (using the node IDs from before this iteration).
                i_jedge_dict_ = defaultdict(list)
                for i, j_edge in i_jedge_dict.items():
                    for j, edge in j_edge:
                        both_side = tuple(sorted((nodes[i], nodes[j])))
                        edge = edge2idx[(both_side, edge)]
                        i_jedge_dict_[i].append((j, edge))

                # Both containers were freshly built in this iteration, so
                # they can be adopted without copying.
                nodes = nodes_
                i_jedge_dict = i_jedge_dict_

            return np.array(nodes)

        # Each mapping assigns the next free integer id to an unseen key.
        atom2idx = defaultdict(lambda: len(atom2idx))
        bond2idx = defaultdict(lambda: len(bond2idx))
        fingerprint2idx = defaultdict(lambda: len(fingerprint2idx))
        edge2idx = defaultdict(lambda: len(edge2idx))
        molecule_set, average_index = [], []

        for smiles_list in self.all_smiles_list:
            # Create each data with the above defined functions.
            counter = 0  # counter how many drugs are under that ATC-3
            for smiles in smiles_list:
                mol = Chem.MolFromSmiles(smiles)
                # skip SMILES strings that cannot be parsed
                if mol is None:
                    continue
                mol = Chem.AddHs(mol)
                atoms = create_atoms(mol, atom2idx)
                molecular_size = len(atoms)
                i_jbond_dict = create_ijbonddict(mol, bond2idx)
                fingerprints = extract_fingerprints(
                    radius, atoms, i_jbond_dict, fingerprint2idx, edge2idx
                )
                adjacency = Chem.GetAdjacencyMatrix(mol)
                # Transform the above each data of numpy to pytorch tensor.
                fingerprints = torch.LongTensor(fingerprints)
                adjacency = torch.FloatTensor(adjacency)
                molecule_set.append((fingerprints, adjacency, molecular_size))
                counter += 1
            average_index.append(counter)

        num_fingerprints = len(fingerprint2idx)

        # transform into projection matrix
        n_col = sum(average_index)
        n_row = len(average_index)
        average_projection = np.zeros((n_row, n_col))
        col_counter = 0
        for i, item in enumerate(average_index):
            # drugs without any molecule keep an all-zero row
            if item > 0:
                average_projection[i, col_counter : col_counter + item] = 1 / item
            col_counter += item
        average_projection = torch.FloatTensor(average_projection)

        return molecule_set, num_fingerprints, average_projection

    def forward(
        self,
        conditions: List[List[List[str]]],
        procedures: List[List[List[str]]],
        drugs: List[List[str]],
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            conditions: a nested list in three levels [patient, visit, condition].
            procedures: a nested list in three levels [patient, visit, procedure].
            drugs: a nested list in two levels [patient, drug].
            **kwargs: ignored.

        Returns:
            A dictionary with the following keys:
                loss: a scalar tensor representing the loss.
                y_prob: a tensor of shape [patient, num_labels] representing
                    the probability of each drug.
                y_true: a tensor of shape [patient, num_labels] representing
                    the ground truth of each drug.
        """
        conditions = self.feat_tokenizers["conditions"].batch_encode_3d(conditions)
        # (patient, visit, code)
        conditions = torch.tensor(conditions, dtype=torch.long, device=self.device)
        # (patient, visit, code, embedding_dim)
        conditions = self.embeddings["conditions"](conditions)
        # (patient, visit, embedding_dim)
        conditions = torch.sum(conditions, dim=2)
        # (patient, visit); computed BEFORE the RNN because the GRU output is
        # generally non-zero even for an all-zero (padded) visit, so a mask
        # taken from it would mark every visit as valid.
        # NOTE(review): assumes padding codes embed to the zero vector - confirm
        # against get_embedding_layers.
        cond_valid = torch.any(conditions != 0, dim=2)
        # (batch, visit, hidden_size)
        conditions, _ = self.cond_rnn(conditions)

        procedures = self.feat_tokenizers["procedures"].batch_encode_3d(procedures)
        # (patient, visit, code)
        procedures = torch.tensor(procedures, dtype=torch.long, device=self.device)
        # (patient, visit, code, embedding_dim)
        procedures = self.embeddings["procedures"](procedures)
        # (patient, visit, embedding_dim)
        procedures = torch.sum(procedures, dim=2)
        # (patient, visit)
        proc_valid = torch.any(procedures != 0, dim=2)
        # (batch, visit, hidden_size)
        procedures, _ = self.proc_rnn(procedures)

        # (batch, visit, 2 * hidden_size)
        patient_emb = torch.cat([conditions, procedures], dim=-1)
        # (batch, visit, hidden_size)
        patient_emb = self.query(patient_emb)

        # get mask: a visit is valid if it carries any condition or procedure
        mask = cond_valid | proc_valid

        drugs = self.prepare_labels(drugs, self.label_tokenizer)

        loss, y_prob = self.safedrug(patient_emb, drugs, mask)

        return {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": drugs,
        }