Source code for pyhealth.models.molerec

import torch
import os
import math
import pkg_resources
import numpy as np
from typing import Any, Dict, List, Tuple, Optional, Union
from rdkit import Chem
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.nn.functional import multilabel_margin_loss

from pyhealth.models import BaseModel
from pyhealth.models.utils import get_last_visit
from pyhealth.models.utils import batch_to_multihot
from pyhealth.metrics import ddi_rate_score
from pyhealth.medcode import ATC
from pyhealth.datasets import SampleEHRDataset

from pyhealth import BASE_CACHE_PATH as CACHE_PATH

def graph_batch_from_smiles(smiles_list, device=torch.device("cpu")):
    """Merge the molecular graphs of ``smiles_list`` into one batched graph.

    Every SMILES string is converted with ``smiles2graph`` (bound as a module
    global by the MoleRec / MoleRecLayer constructors). The node ids of each
    graph are shifted by the number of nodes in all graphs before it, so the
    graphs share a single node numbering.

    Args:
        smiles_list: SMILES strings, one per graph. Must be non-empty
            (``np.concatenate`` raises on an empty list).
        device: device the returned tensors are moved to.

    Returns:
        dict with tensors ``edge_index`` (per-graph edge indices concatenated
        on the last axis), ``edge_attr``, ``batch`` (graph id of every node)
        and ``x`` (node features), plus the ints ``num_nodes`` and
        ``num_edges``.
    """
    graphs = [smiles2graph(smiles) for smiles in smiles_list]

    edge_index_parts, edge_attr_parts, batch_parts, node_feat_parts = [], [], [], []
    node_offset = 0
    for graph_id, graph in enumerate(graphs):
        n_nodes = graph["num_nodes"]
        edge_index_parts.append(graph["edge_index"] + node_offset)
        edge_attr_parts.append(graph["edge_feat"])
        node_feat_parts.append(graph["node_feat"])
        batch_parts.append(np.full(n_nodes, graph_id, dtype=np.int64))
        node_offset += n_nodes

    # Key order matters: StaticParaDict registers parameters in this order.
    arrays = {
        "edge_index": np.concatenate(edge_index_parts, axis=-1),
        "edge_attr": np.concatenate(edge_attr_parts, axis=0),
        "batch": np.concatenate(batch_parts, axis=0),
        "x": np.concatenate(node_feat_parts, axis=0),
    }
    merged = {name: torch.from_numpy(arr).to(device) for name, arr in arrays.items()}
    merged["num_nodes"] = node_offset
    merged["num_edges"] = merged["edge_index"].shape[1]
    return merged


class StaticParaDict(torch.nn.Module):
    """A module holding named, frozen values accessible by attribute or key.

    Tensors and numpy arrays are wrapped as ``torch.nn.Parameter`` with
    ``requires_grad=False`` (so they move with ``.to(device)`` and appear in
    the state dict); any other value is stored as a plain attribute.
    """

    def __init__(self, **kwargs):
        super().__init__()
        for name, value in kwargs.items():
            self[name] = value

    def forward(self, key: str) -> Any:
        """Return the value stored under ``key``."""
        return getattr(self, key)

    def __getitem__(self, key: str) -> Any:
        # Routed through __call__ so it behaves exactly like ``self(key)``.
        return self(key)

    def __setitem__(self, key: str, value: Any):
        """Store ``value`` under ``key``, freezing tensors/arrays as parameters."""
        tensor = torch.from_numpy(value) if isinstance(value, np.ndarray) else value
        if isinstance(tensor, torch.Tensor):
            tensor = torch.nn.Parameter(tensor, requires_grad=False)
        setattr(self, key, tensor)


class GINConv(torch.nn.Module):
    """One GIN convolution layer whose messages include bond features.

    For every edge ``e``, the message
    ``relu(h[edge_index[1][e]] + bond_encoder(edge_feats)[e])`` is summed into
    node ``edge_index[0][e]``. The layer output is
    ``mlp((1 + eps) * h + summed_messages)`` with a learnable scalar ``eps``.

    Args:
        embedding_dim: node/edge embedding size (input and output).
    """

    def __init__(self, embedding_dim: int = 64):
        super().__init__()
        hidden_dim = 2 * embedding_dim
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, embedding_dim),
        )
        self.eps = torch.nn.Parameter(torch.Tensor([0]))
        # BondEncoder is bound as a module global by the MoleRec constructors.
        self.bond_encoder = BondEncoder(emb_dim=embedding_dim)

    def forward(
        self,
        node_feats: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
        num_nodes: int,
        num_edges: int,
    ) -> torch.Tensor:
        """Apply one convolution.

        Args:
            node_feats: node embeddings, one row per node.
            edge_feats: raw bond features passed to ``bond_encoder``.
            edge_index: 2-row index tensor; row 1 selects message sources,
                row 0 selects the nodes the messages are summed into.
            num_nodes: number of rows of the aggregated message tensor.
            num_edges: unused; kept for interface compatibility.

        Returns:
            updated node embeddings, same number of rows as ``node_feats``.
        """
        bond_emb = self.bond_encoder(edge_feats)
        gather_idx, scatter_idx = edge_index[1], edge_index[0]
        messages = torch.relu(node_feats.index_select(0, gather_idx) + bond_emb)

        width = messages.shape[-1]
        aggregated = messages.new_zeros(num_nodes, width)
        aggregated.scatter_add_(0, scatter_idx.unsqueeze(-1).repeat(1, width), messages)

        return self.mlp((1 + self.eps) * node_feats + aggregated)


class GINGraph(torch.nn.Module):
    """Graph encoder: a stack of GINConv layers followed by per-graph mean pooling.

    Args:
        num_layers: number of GINConv layers; must be at least 2.
        embedding_dim: node embedding size.
        dropout: dropout probability applied after every layer.

    Raises:
        ValueError: if ``num_layers`` is smaller than 2.
    """

    def __init__(
        self, num_layers: int = 4, embedding_dim: int = 64, dropout: float = 0.7
    ):
        super().__init__()
        if num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        # AtomEncoder is bound as a module global by the MoleRec constructors.
        self.atom_encoder = AtomEncoder(emb_dim=embedding_dim)
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        self.num_layers = num_layers
        self.dropout_fun = torch.nn.Dropout(dropout)
        for _ in range(num_layers):
            self.convs.append(GINConv(embedding_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(embedding_dim))

    def forward(self, graph: Dict[str, Union[int, torch.Tensor]]) -> torch.Tensor:
        """Encode a batched graph into one vector per graph.

        Args:
            graph: mapping with keys ``x``, ``edge_attr``, ``edge_index``,
                ``num_nodes``, ``num_edges`` and ``batch`` (graph id of each
                node), e.g. the output of ``graph_batch_from_smiles``.

        Returns:
            tensor with ``graph["batch"].max() + 1`` rows: the mean of the
            final node embeddings of each graph.
        """
        h = self.atom_encoder(graph["x"])
        last_layer = self.num_layers - 1
        for layer in range(self.num_layers):
            conv_out = self.convs[layer](
                node_feats=h,
                edge_feats=graph["edge_attr"],
                edge_index=graph["edge_index"],
                num_nodes=graph["num_nodes"],
                num_edges=graph["num_edges"],
            )
            h = self.batch_norms[layer](conv_out)
            # Every layer except the last gets a ReLU; all get dropout.
            if layer < last_layer:
                h = torch.relu(h)
            h = self.dropout_fun(h)

        # Mean-pool node embeddings by graph id.
        graph_ids = graph["batch"]
        num_graphs = graph_ids.max().item() + 1
        width = h.shape[-1]
        index = graph_ids.unsqueeze(-1).repeat(1, width)
        totals = h.new_zeros(num_graphs, width).scatter_add_(0, index, h)
        counts = h.new_zeros(num_graphs, width).scatter_add_(0, index, torch.ones_like(h))
        return totals / (counts + 1e-9)


class MAB(torch.nn.Module):
    """Multi-head attention block (queries from ``X``, keys/values from ``Y``).

    Each head outputs its query projection plus the attention-weighted values
    (a residual on the queries). The heads are concatenated and passed through
    ``Odense``, optionally wrapped by LayerNorms.

    Args:
        Qdim: feature size of ``X``.
        Kdim: feature size of ``Y``.
        Vdim: output feature size; must be divisible by ``number_heads``.
        number_heads: number of attention heads.
        use_ln: whether to apply LayerNorm before and after ``Odense``.
    """

    def __init__(
        self, Qdim: int, Kdim: int, Vdim: int, number_heads: int, use_ln: bool = False
    ):
        super().__init__()
        self.Vdim = Vdim
        self.number_heads = number_heads

        assert (
            self.Vdim % self.number_heads == 0
        ), "the dim of features should be divisible by number_heads"

        self.Qdense = torch.nn.Linear(Qdim, Vdim)
        self.Kdense = torch.nn.Linear(Kdim, Vdim)
        self.Vdense = torch.nn.Linear(Kdim, Vdim)
        self.Odense = torch.nn.Linear(Vdim, Vdim)

        self.use_ln = use_ln
        if use_ln:
            self.ln1 = torch.nn.LayerNorm(Vdim)
            self.ln2 = torch.nn.LayerNorm(Vdim)

    def forward(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        """Attend from ``X`` [batch, n_x, Qdim] to ``Y`` [batch, n_y, Kdim].

        Returns:
            tensor of shape [batch, n_x, Vdim].
        """
        heads = self.number_heads
        head_dim = self.Vdim // heads

        def stack_heads(t: torch.Tensor) -> torch.Tensor:
            # [B, N, Vdim] -> [heads * B, N, head_dim]; rows h*B..(h+1)*B-1 hold head h.
            return torch.cat(torch.chunk(t, heads, dim=2), dim=0)

        q = stack_heads(self.Qdense(X))
        k = stack_heads(self.Kdense(Y))
        v = stack_heads(self.Vdense(Y))

        scores = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(head_dim)
        out = q + torch.matmul(torch.softmax(scores, dim=-1), v)
        # Undo stack_heads: [heads * B, N, head_dim] -> [B, N, Vdim].
        out = torch.cat(torch.chunk(out, heads, dim=0), dim=2)

        if self.use_ln:
            out = self.ln1(out)
        out = self.Odense(out)
        if self.use_ln:
            out = self.ln2(out)
        return out


class SAB(torch.nn.Module):
    """Set attention block: a ``MAB`` where the set attends to itself.

    Args:
        in_dim: feature size of the input set.
        out_dim: output feature size; must be divisible by ``number_heads``.
        number_heads: number of attention heads.
        use_ln: whether the inner ``MAB`` uses LayerNorm.
    """

    def __init__(
        self, in_dim: int, out_dim: int, number_heads: int, use_ln: bool = False
    ):
        super().__init__()
        self.net = MAB(
            Qdim=in_dim,
            Kdim=in_dim,
            Vdim=out_dim,
            number_heads=number_heads,
            use_ln=use_ln,
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Self-attention over ``X``: queries, keys and values all come from ``X``."""
        return self.net(X, Y=X)


class AttnAgg(torch.nn.Module):
    """Adjusted attention aggregator.

    Scores every "main" object against every "other" object with scaled
    dot-product attention over learned projections. It then aggregates the
    raw ``other_feat`` rows, each rescaled per batch element by ``fix_feat``.

    Args:
        Qdim: feature size of ``main_feat``.
        Kdim: feature size of ``other_feat``.
        mid_dim: size of the shared projection space used for the scores.
    """

    def __init__(self, Qdim: int, Kdim: int, mid_dim: int):
        super(AttnAgg, self).__init__()
        self.model_dim = mid_dim
        self.Qdense = torch.nn.Linear(Qdim, mid_dim)
        self.Kdense = torch.nn.Linear(Kdim, mid_dim)

    def forward(
        self,
        main_feat: torch.Tensor,
        other_feat: torch.Tensor,
        fix_feat: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward propagation.

        Adjusted Attention Aggregator

        Args:
            main_feat (torch.Tensor): shape of [main_num, Q_dim]
            other_feat (torch.Tensor): shape of [other_num, K_dim]
            fix_feat (torch.Tensor): shape of [batch, other_num],
                adjust parameter for attention weight
            mask (torch.Tensor): boolean tensor of shape [main_num, other_num]
                marking where a main object must NOT attend to an other
                object (True means no attention). (default: `None`)

        Returns:
            torch.Tensor: aggregated features, shape of
                [batch, main_num, K_dim]
        """
        Q = self.Qdense(main_feat)
        K = self.Kdense(other_feat)
        # [main_num, other_num]
        Attn = torch.matmul(Q, K.transpose(0, 1)) / math.sqrt(self.model_dim)

        if mask is not None:
            # A large finite negative (rather than -inf) means a fully-masked
            # row softmaxes to a uniform distribution instead of NaN.
            Attn = torch.masked_fill(Attn, mask, -(1 << 32))
        Attn = torch.softmax(Attn, dim=-1)

        # Scale row j of other_feat by fix_feat[b, j]:
        # [batch, other_num, 1] * [other_num, K_dim] -> [batch, other_num, K_dim].
        # Same result as diag_embed(fix_feat) @ other_feat, without building
        # the [batch, other_num, other_num] diagonal matrices.
        weighted_other = fix_feat.unsqueeze(-1) * other_feat

        # [main_num, other_num] @ [batch, other_num, K_dim]
        # broadcasts to [batch, main_num, K_dim].
        return torch.matmul(Attn, weighted_other)


class MoleRecLayer(torch.nn.Module):
    """MoleRec model.

    Paper: Nianzu Yang et al. MoleRec: Combinatorial Drug Recommendation
    with Substructure-Aware Molecular Representation Learning. WWW 2023.

    This layer is used in the MoleRec model. But it can also be used as a
    standalone layer.

    Args:
        hidden_size: hidden feature size.
        coef: coefficient of ddi loss weight annealing. larger coefficient
            means higher penalty to the drug-drug-interaction. Default is 2.5.
        target_ddi: DDI acceptance rate. Default is 0.08.
        GNN_layers: the number of layers of GNNs encoding molecule and
            substructures. Default is 4.
        dropout: the dropout ratio of model. Default is 0.5.
        multiloss_weight: the weight of multilabel_margin_loss for
            multilabel classification. Value should be set between [0, 1].
            Default is 0.05.
        **kwargs: accepted for compatibility and ignored.
    """

    def __init__(
        self,
        hidden_size: int,
        coef: float = 2.5,
        target_ddi: float = 0.08,
        GNN_layers: int = 4,
        dropout: float = 0.5,
        multiloss_weight: float = 0.05,
        **kwargs,
    ):
        # The ogb names are bound as module globals because GINConv, GINGraph
        # and graph_batch_from_smiles look them up at module level.
        global smiles2graph, AtomEncoder, BondEncoder
        super(MoleRecLayer, self).__init__()
        dependencies = ["ogb>=1.3.5"]

        # test whether the ogb package is ready
        try:
            pkg_resources.require(dependencies)
            from ogb.utils import smiles2graph
            from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
        except Exception as e:
            print(
                "Please follow the error message and install the [ogb>=1.3.5] packages first."
            )
            print(e)

        self.hidden_size = hidden_size
        self.coef, self.target_ddi = coef, target_ddi
        GNN_para = {
            "num_layers": GNN_layers,
            "dropout": dropout,
            "embedding_dim": hidden_size,
        }
        self.substructure_encoder = GINGraph(**GNN_para)
        self.molecule_encoder = GINGraph(**GNN_para)
        self.substructure_interaction_module = SAB(
            hidden_size, hidden_size, 2, use_ln=True
        )
        self.combination_feature_aggregator = AttnAgg(
            hidden_size, hidden_size, hidden_size
        )
        score_extractor = [
            torch.nn.Linear(hidden_size, hidden_size // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size // 2, 1),
        ]
        self.score_extractor = torch.nn.Sequential(*score_extractor)
        self.multiloss_weight = multiloss_weight

    def calc_loss(
        self,
        logits: torch.Tensor,
        y_prob: torch.Tensor,
        ddi_adj: torch.Tensor,
        labels: torch.Tensor,
        label_index: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the classification loss with DDI-aware annealing.

        The classification loss is BCE on ``logits``. When
        ``multiloss_weight > 0`` and ``label_index`` is given, it is mixed with a
        multilabel margin loss on ``y_prob``. If the DDI rate of the predictions
        thresholded at 0.5 exceeds ``target_ddi``, the loss becomes
        ``beta * loss_cls + (1 - beta) * ddi_loss`` with
        ``beta = min(exp(coef * (1 - rate / target_ddi)), 1)``.

        Args:
            logits: raw scores of shape [patient, num_labels].
            y_prob: ``sigmoid(logits)``, same shape.
            ddi_adj: DDI adjacency of shape [num_labels, num_labels].
            labels: multihot ground truth, same shape as ``logits``.
            label_index: ground-truth label indices padded with -1, as
                expected by ``multilabel_margin_loss``.

        Returns:
            scalar loss tensor.
        """
        # (voc_size, voc_size)
        mul_pred_prob = y_prob.T @ y_prob
        ddi_loss = (mul_pred_prob * ddi_adj).sum() / (ddi_adj.shape[0] ** 2)

        # Indices of the drugs predicted for each patient (prob >= 0.5).
        y_pred = [np.where(row >= 0.5)[0] for row in y_prob.detach().cpu().numpy()]

        loss_cls = binary_cross_entropy_with_logits(logits, labels)
        if self.multiloss_weight > 0 and label_index is not None:
            loss_multi = multilabel_margin_loss(y_prob, label_index)
            loss_cls = (
                self.multiloss_weight * loss_multi
                + (1 - self.multiloss_weight) * loss_cls
            )

        cur_ddi_rate = ddi_rate_score(y_pred, ddi_adj.cpu().numpy())
        if cur_ddi_rate > self.target_ddi:
            beta = self.coef * (1 - (cur_ddi_rate / self.target_ddi))
            beta = min(math.exp(beta), 1)
            loss = beta * loss_cls + (1 - beta) * ddi_loss
        else:
            loss = loss_cls
        return loss

    def forward(
        self,
        patient_emb: torch.Tensor,
        drugs: torch.Tensor,
        average_projection: torch.Tensor,
        ddi_adj: torch.Tensor,
        substructure_mask: torch.Tensor,
        substructure_graph: Union[StaticParaDict, Dict[str, Union[int, torch.Tensor]]],
        molecule_graph: Union[StaticParaDict, Dict[str, Union[int, torch.Tensor]]],
        mask: Optional[torch.Tensor] = None,
        drug_indexes: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward propagation.

        Args:
            patient_emb: a tensor of shape [patient, visit, num_substructures],
                representing the relation between each patient visit and
                each substructure.
            drugs: a multihot tensor of shape [patient, num_labels].
            mask: an optional boolean tensor of shape [patient, visit] where
                True indicates valid visits and False indicates invalid visits.
                Defaults to all visits being valid.
            substructure_mask: tensor of shape [num_drugs, num_substructures],
                representing whether a substructure shows up in one of the
                molecules of each drug.
            average_projection: a tensor of shape [num_drugs, num_molecules]
                representing the average projection for aggregating multiple
                molecules of the same drug into one vector.
            substructure_graph: a dictionary representing a graph batch of all
                substructures, where each graph is extracted via the
                'smiles2graph' api of the ogb library.
            molecule_graph: dictionary with the same form as
                substructure_graph, representing the graph batch of all
                molecules.
            ddi_adj: an adjacency tensor for drug drug interaction of shape
                [num_drugs, num_drugs].
            drug_indexes: the index version of drugs (ground truth) of shape
                [patient, num_labels], padded with -1.

        Returns:
            loss: a scalar tensor representing the loss.
            y_prob: a tensor of shape [patient, num_labels] representing
                the probability of each drug.
        """
        if mask is None:
            # Every visit is valid; boolean like the mask MoleRec passes in.
            mask = torch.ones_like(patient_emb[:, :, 0], dtype=torch.bool)
        # [patient, num_substructures]
        substructure_relation = get_last_visit(patient_emb, mask)

        # [num_substructures, hidden]
        substructure_embedding = self.substructure_interaction_module(
            self.substructure_encoder(substructure_graph).unsqueeze(0)
        ).squeeze(0)

        if substructure_relation.shape[-1] != substructure_embedding.shape[0]:
            raise RuntimeError(
                "the substructure relation vector of each patient should have "
                "the same dimension as the number of substructure"
            )

        # One row per molecule, then averaged into one row per drug.
        molecule_embedding = self.molecule_encoder(molecule_graph)
        molecule_embedding = torch.mm(average_projection, molecule_embedding)

        # [patient, num_drugs, hidden]
        combination_embedding = self.combination_feature_aggregator(
            molecule_embedding,
            substructure_embedding,
            substructure_relation,
            torch.logical_not(substructure_mask > 0),
        )
        # [patient, num_drugs]
        logits = self.score_extractor(combination_embedding).squeeze(-1)

        y_prob = torch.sigmoid(logits)

        loss = self.calc_loss(logits, y_prob, ddi_adj, drugs, drug_indexes)

        return loss, y_prob
class MoleRec(BaseModel):
    """MoleRec model.

    Paper: Nianzu Yang et al. MoleRec: Combinatorial Drug Recommendation
    with Substructure-Aware Molecular Representation Learning. WWW 2023.

    Note:
        This model is only for medication prediction which takes conditions
        and procedures as feature_keys, and drugs as label_key. It only
        operates on the visit level.

    Note:
        This model only accepts ATC level 3 as medication codes.

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        embedding_dim: the embedding dimension. Default is 64.
        hidden_dim: the hidden dimension. Default is 64.
        num_rnn_layers: the number of layers used in RNN. Default is 1.
        num_gnn_layers: the number of layers used in GNN. Default is 4.
        dropout: the dropout rate. Default is 0.5.
        **kwargs: other parameters for the MoleRec layer. ``GNN_layers`` and
            ``hidden_size`` are rejected with ValueError; use
            ``num_gnn_layers`` and ``hidden_dim`` instead.
    """

    def __init__(
        self,
        dataset: SampleEHRDataset,
        embedding_dim: int = 64,
        hidden_dim: int = 64,
        num_rnn_layers: int = 1,
        num_gnn_layers: int = 4,
        dropout: float = 0.5,
        **kwargs,
    ):
        # The ogb names are bound as module globals so that
        # graph_batch_from_smiles (called below) and the GNN encoders find them.
        global smiles2graph, AtomEncoder, BondEncoder
        super(MoleRec, self).__init__(
            dataset=dataset,
            feature_keys=["conditions", "procedures"],
            label_key="drugs",
            mode="multilabel",
        )
        # Validate up front: these keys are passed explicitly to MoleRecLayer
        # below, so forwarding them again via **kwargs would otherwise raise a
        # TypeError ("multiple values for keyword argument") before any check.
        if "GNN_layers" in kwargs:
            raise ValueError("number of GNN layers is determined by num_gnn_layers")
        if "hidden_size" in kwargs:
            raise ValueError("hidden_size is determined by hidden_dim")

        dependencies = ["ogb>=1.3.5"]

        # test whether the ogb package is ready
        try:
            pkg_resources.require(dependencies)
            from ogb.utils import smiles2graph
            from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
        except Exception as e:
            print(
                "Please follow the error message and install the [ogb>=1.3.5] packages first."
            )
            print(e)

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_rnn_layers = num_rnn_layers
        self.num_gnn_layers = num_gnn_layers
        self.dropout = dropout
        self.dropout_fn = torch.nn.Dropout(dropout)

        self.feat_tokenizers = self.get_feature_tokenizers()
        self.label_tokenizer = self.get_label_tokenizer()
        self.embeddings = self.get_embedding_layers(self.feat_tokenizers, embedding_dim)

        self.label_size = self.label_tokenizer.get_vocabulary_size()
        self.ddi_adj = torch.nn.Parameter(self.generate_ddi_adj(), requires_grad=False)
        self.all_smiles_list = self.generate_smiles_list()

        substructure_mask, self.substructure_smiles = self.generate_substructure_mask()
        self.substructure_mask = torch.nn.Parameter(
            substructure_mask, requires_grad=False
        )

        average_projection, self.all_smiles_flatten = self.generate_average_projection()
        self.average_projection = torch.nn.Parameter(
            average_projection, requires_grad=False
        )
        self.substructure_graphs = StaticParaDict(
            **graph_batch_from_smiles(self.substructure_smiles)
        )
        self.molecule_graphs = StaticParaDict(
            **graph_batch_from_smiles(self.all_smiles_flatten)
        )

        self.rnns = torch.nn.ModuleDict(
            {
                x: torch.nn.GRU(
                    embedding_dim,
                    hidden_dim,
                    num_layers=num_rnn_layers,
                    dropout=dropout if num_rnn_layers > 1 else 0,
                    batch_first=True,
                )
                for x in ["conditions", "procedures"]
            }
        )
        num_substructures = substructure_mask.shape[1]
        self.substructure_relation = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim * 2, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, num_substructures),
        )
        self.layer = MoleRecLayer(
            hidden_size=hidden_dim,
            dropout=dropout,
            GNN_layers=num_gnn_layers,
            **kwargs,
        )

        # save ddi adj (reuse the matrix built above instead of regenerating it)
        np.save(
            os.path.join(CACHE_PATH, "ddi_adj.npy"),
            self.ddi_adj.detach().cpu().numpy(),
        )

    def generate_ddi_adj(self) -> torch.FloatTensor:
        """Generates the DDI graph adjacency matrix.

        Uses the GAMENet DDI pairs from ``ATC().get_ddi``. Each code is mapped
        to ATC level 3, and both directions are marked for every pair whose two
        codes are in the label vocabulary.

        Returns:
            float tensor of shape [label_size, label_size] with 1 for
            interacting pairs and 0 elsewhere.
        """
        atc = ATC()
        ddi = atc.get_ddi(gamenet_ddi=True)
        vocab_to_index = self.label_tokenizer.vocabulary
        ddi_adj = np.zeros((self.label_size, self.label_size))
        ddi_atc3 = [
            [ATC.convert(pair[0], level=3), ATC.convert(pair[1], level=3)]
            for pair in ddi
        ]
        for atc_i, atc_j in ddi_atc3:
            if atc_i in vocab_to_index and atc_j in vocab_to_index:
                ddi_adj[vocab_to_index(atc_i), vocab_to_index(atc_j)] = 1
                ddi_adj[vocab_to_index(atc_j), vocab_to_index(atc_i)] = 1
        return torch.FloatTensor(ddi_adj)

    def generate_substructure_mask(self) -> Tuple[torch.Tensor, List[str]]:
        """Generates the molecular segmentation mask H and substructure smiles.

        Every parsable molecule of each drug is decomposed with RDKit BRICS.
        Substructures are sorted so that the column order does not depend on
        the interpreter's string hash seed.

        Returns:
            mask_H: float64 tensor of shape [label_size, num_substructures],
                1 where the substructure occurs in a molecule of the drug.
            substructures_set: substructure SMILES in column order.
        """
        # Explicit submodule import: ``from rdkit import Chem`` alone does not
        # guarantee that ``Chem.BRICS`` is loaded.
        from rdkit.Chem import BRICS

        all_substructures_list = [[] for _ in range(self.label_size)]
        for index, smiles_list in enumerate(self.all_smiles_list):
            for smiles in smiles_list:
                mol = Chem.MolFromSmiles(smiles)
                if mol is None:
                    continue
                substructures = BRICS.BRICSDecompose(mol)
                all_substructures_list[index] += substructures

        # all segment set, in a deterministic order
        substructures_set = sorted(
            {s for substructures in all_substructures_list for s in substructures}
        )
        column_of = {s: col for col, s in enumerate(substructures_set)}

        # mask_H
        mask_H = np.zeros((self.label_size, len(substructures_set)))
        for index, substructures in enumerate(all_substructures_list):
            for s in substructures:
                mask_H[index, column_of[s]] = 1
        mask_H = torch.from_numpy(mask_H)

        return mask_H, substructures_set

    def generate_smiles_list(self) -> List[List[str]]:
        """Generates the list of SMILES strings.

        Returns:
            list of length label_size. Entry i holds the SMILES kept for the
            ATC-3 class with label index i: at most one, the first non-NaN
            SMILES found among its 7-character ATC codes.
        """
        atc3_to_smiles = {}
        atc = ATC()
        for code in atc.graph.nodes:
            if len(code) != 7:
                continue
            code_atc3 = ATC.convert(code, level=3)
            smiles = atc.graph.nodes[code]["smiles"]
            # skip NaN entries (NaN is the only value unequal to itself)
            if smiles != smiles:
                continue
            # just take first one for computational efficiency
            atc3_to_smiles.setdefault(code_atc3, [smiles])

        all_smiles_list = [[] for _ in range(self.label_size)]
        vocab_to_index = self.label_tokenizer.vocabulary
        for atc3, smiles_list in atc3_to_smiles.items():
            if atc3 in vocab_to_index:
                index = vocab_to_index(atc3)
                all_smiles_list[index] += smiles_list

        return all_smiles_list

    def generate_average_projection(self) -> Tuple[torch.Tensor, List[str]]:
        """Build the matrix that averages molecule embeddings into drug embeddings.

        Returns:
            average_projection: float tensor of shape [label_size, num_molecules].
                Row i holds 1 / k over the columns of the k parsable molecules
                of drug i, or all zeros when drug i has none.
            molecule_set: the parsable SMILES strings, in column order.
        """
        molecule_set, average_index = [], []
        for smiles_list in self.all_smiles_list:
            counter = 0  # how many parsable molecules are under this ATC-3 class
            for smiles in smiles_list:
                mol = Chem.MolFromSmiles(smiles)
                if mol is None:
                    continue
                molecule_set.append(smiles)
                counter += 1
            average_index.append(counter)

        average_projection = np.zeros((len(average_index), sum(average_index)))
        col_counter = 0
        for i, item in enumerate(average_index):
            if item <= 0:
                continue
            average_projection[i, col_counter : col_counter + item] = 1 / item
            col_counter += item
        average_projection = torch.FloatTensor(average_projection)
        return average_projection, molecule_set

    def _encode_patient_with_mask(
        self, feature_key: str, raw_values: List[List[List[str]]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode one feature with its GRU and also return a visit-validity mask.

        Returns:
            outputs: GRU outputs, one step per visit.
            mask: boolean tensor [patient, visit], True for visits that
                contain at least one non-padding code.
        """
        tokenizer = self.feat_tokenizers[feature_key]
        codes = tokenizer.batch_encode_3d(raw_values)
        codes = torch.tensor(codes, dtype=torch.long, device=self.device)
        mask = (codes != tokenizer.get_padding_index()).any(dim=2)
        embeddings = self.embeddings[feature_key](codes)
        embeddings = torch.sum(self.dropout_fn(embeddings), dim=2)
        outputs, _ = self.rnns[feature_key](embeddings)
        return outputs, mask

    def encode_patient(
        self, feature_key: str, raw_values: List[List[List[str]]]
    ) -> torch.Tensor:
        """Encode one feature (e.g. "conditions") into per-visit GRU outputs."""
        outputs, _ = self._encode_patient_with_mask(feature_key, raw_values)
        return outputs

    def forward(
        self,
        conditions: List[List[List[str]]],
        procedures: List[List[List[str]]],
        drugs: List[List[str]],
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            conditions: a nested list in three levels with
                shape [patient, visit, condition].
            procedures: a nested list in three levels with
                shape [patient, visit, procedure].
            drugs: a nested list in two levels [patient, drug].

        Returns:
            A dictionary with the following keys:
                loss: a scalar tensor representing the loss.
                y_prob: a tensor of shape [patient, num_labels] representing
                    the probability of each drug.
                y_true: a tensor of shape [patient, num_labels] representing
                    the ground truth of each drug.
        """
        # prepare labels
        labels_index = self.label_tokenizer.batch_encode_2d(
            drugs, padding=False, truncation=False
        )
        # convert to multihot
        labels = batch_to_multihot(labels_index, self.label_size)

        # index form for multilabel_margin_loss: unique label ids, padded with -1
        index_labels = -np.ones((len(labels), self.label_size), dtype=np.int64)
        for idx, cont in enumerate(labels_index):
            # remove redundant labels
            cont = list(set(cont))
            index_labels[idx, : len(cont)] = cont
        index_labels = torch.from_numpy(index_labels)

        labels = labels.to(self.device)
        index_labels = index_labels.to(self.device)

        # encoding procs and diags
        condition_emb, condition_mask = self._encode_patient_with_mask(
            "conditions", conditions
        )
        procedure_emb, procedure_mask = self._encode_patient_with_mask(
            "procedures", procedures
        )
        # A visit is valid if it has at least one real (non-padding) code. This
        # must come from the codes: GRU outputs at padded steps are generally
        # non-zero, so testing them against 0 would mark every visit valid.
        mask = condition_mask | procedure_mask

        patient_emb = torch.cat([condition_emb, procedure_emb], dim=-1)
        substruct_rela = self.substructure_relation(patient_emb)

        # calculate loss
        loss, y_prob = self.layer(
            patient_emb=substruct_rela,
            drugs=labels,
            ddi_adj=self.ddi_adj,
            average_projection=self.average_projection,
            substructure_mask=self.substructure_mask,
            substructure_graph=self.substructure_graphs,
            molecule_graph=self.molecule_graphs,
            mask=mask,
            drug_indexes=index_labels,
        )

        return {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": labels,
        }