Graph

The pyhealth.graph module lets you bring a healthcare knowledge graph into your PyHealth pipeline. Graph-based models such as GraphCare and GNN can use relational medical knowledge (drug interactions, disease hierarchies, symptom–diagnosis links) to enrich patient representations beyond what raw EHR codes alone capture.

What Is a Knowledge Graph?

A knowledge graph encodes medical relationships as (head, relation, tail) triples. For example:

  • ("aspirin", "treats", "headache")

  • ("metformin", "used_for", "type_2_diabetes")

  • ("ICD9:250.00", "is_a", "ICD9:250")

PyHealth does not ship a built-in graph: you supply triples from a source of your choice (UMLS, DrugBank, an ICD hierarchy, a custom ontology, etc.), and the KnowledgeGraph class handles indexing, entity-to-ID mappings, and k-hop subgraph extraction. The typical use is to query the graph at training time: given a patient’s active codes, extract the local subgraph around those codes and feed it to a graph-aware model.

Getting Started

The simplest way to create a graph is to pass a list of triples directly:

from pyhealth.graph import KnowledgeGraph

triples = [
    ("aspirin",    "treats",     "headache"),
    ("headache",   "symptom_of", "migraine"),
    ("ibuprofen",  "treats",     "headache"),
]
kg = KnowledgeGraph(triples=triples)
kg.stat()
# KnowledgeGraph: 4 entities, 2 relations, 3 triples
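The summary counts are simply the sizes of the distinct entity, relation, and triple sets, with heads and tails sharing one entity space. As a plain-Python illustration (not PyHealth's implementation; `count_graph` is a hypothetical helper, and deduplicating repeated triples is an assumption), the three numbers above can be reproduced like this:

```python
def count_graph(triples):
    """Return (num_entities, num_relations, num_triples) for a list of
    (head, relation, tail) tuples. Hypothetical helper for illustration."""
    entities = set()
    relations = set()
    for head, relation, tail in triples:
        entities.update((head, tail))  # heads and tails share one entity space
        relations.add(relation)
    # Assumption: a triple that appears twice is counted once.
    return len(entities), len(relations), len(set(triples))


triples = [
    ("aspirin",    "treats",     "headache"),
    ("headache",   "symptom_of", "migraine"),
    ("ibuprofen",  "treats",     "headache"),
]
print(count_graph(triples))  # (4, 2, 3)
```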

For larger graphs it is more practical to load from a CSV or TSV file. The file should have columns named head, relation, and tail:

kg = KnowledgeGraph(triples="path/to/medical_kg.tsv")
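If you assemble the file yourself, it can help to check the header before loading. The sketch below uses only Python's standard `csv` module to read a delimited file with head, relation, and tail columns into a list of tuples, which can then be passed as `KnowledgeGraph(triples=...)`. `read_triples` is a hypothetical helper, not part of PyHealth, and choosing the delimiter by hand is an assumption about your files.

```python
import csv
import io

REQUIRED = ("head", "relation", "tail")


def read_triples(f, delimiter="\t"):
    """Read (head, relation, tail) tuples from a delimited text stream with a
    header row. Hypothetical helper; raises ValueError if columns are missing."""
    reader = csv.DictReader(f, delimiter=delimiter)
    missing = [c for c in REQUIRED if c not in (reader.fieldnames or [])]
    if missing:
        raise ValueError(f"missing columns: {missing}")
    return [(row["head"], row["relation"], row["tail"]) for row in reader]


# In-memory TSV standing in for path/to/medical_kg.tsv
tsv = io.StringIO(
    "head\trelation\ttail\n"
    "aspirin\ttreats\theadache\n"
    "headache\tsymptom_of\tmigraine\n"
)
print(read_triples(tsv))
# [('aspirin', 'treats', 'headache'), ('headache', 'symptom_of', 'migraine')]
```

For a real file, open it with `open(path, newline="")` and pass `delimiter=","` for CSV.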

Exploring the Graph

Once built, you can inspect the graph and look up neighbours for any entity:

kg.num_entities     # total unique entities
kg.num_relations    # total unique relation types
kg.num_triples      # total edges

kg.has_entity("aspirin")     # True / False
kg.neighbors("aspirin")      # list of (relation, tail) pairs

# Integer ID mappings used internally by PyG
kg.entity2id["aspirin"]      # → int
kg.id2entity[0]              # → entity name string
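To make these mappings concrete, here is a plain-Python sketch of deriving integer IDs and an outgoing-neighbor index from triples, ending with the `[2, num_edges]` `edge_index` layout (sources in row 0, targets in row 1) that PyTorch Geometric uses for edges. Numbering entities in order of first appearance and the helper name `build_index` are assumptions for illustration; PyHealth may number entities differently, so always go through `kg.entity2id` rather than assuming an order.

```python
from collections import defaultdict


def build_index(triples):
    """Hypothetical helper: assign integer IDs in order of first appearance
    and build an outgoing-neighbor index plus a PyG-style edge_index."""
    entity2id, relation2id = {}, {}
    neighbors = defaultdict(list)  # head -> [(relation, tail), ...]
    for head, relation, tail in triples:
        for entity in (head, tail):
            entity2id.setdefault(entity, len(entity2id))
        relation2id.setdefault(relation, len(relation2id))
        neighbors[head].append((relation, tail))
    id2entity = {i: e for e, i in entity2id.items()}
    # Row 0 holds source (head) IDs, row 1 holds target (tail) IDs.
    edge_index = [
        [entity2id[h] for h, _, _ in triples],
        [entity2id[t] for _, _, t in triples],
    ]
    return entity2id, id2entity, relation2id, dict(neighbors), edge_index


triples = [
    ("aspirin",    "treats",     "headache"),
    ("headache",   "symptom_of", "migraine"),
    ("ibuprofen",  "treats",     "headache"),
]
entity2id, id2entity, relation2id, neighbors, edge_index = build_index(triples)
print(entity2id)             # {'aspirin': 0, 'headache': 1, 'migraine': 2, 'ibuprofen': 3}
print(neighbors["aspirin"])  # [('treats', 'headache')]
print(edge_index)            # [[0, 1, 3], [1, 2, 1]]
```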

Extracting Patient Subgraphs

The main reason to build a knowledge graph is to extract a patient-specific subgraph at training time. subgraph() returns all entities reachable within num_hops hops of a set of seed entities (for example, a patient’s codes), packaged as a PyTorch Geometric Data object:

patient_codes = ["ICD9:250.00", "NDC:0069-0105"]
subgraph = kg.subgraph(seed_entities=patient_codes, num_hops=2)
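Conceptually, k-hop extraction is a breadth-first search that starts from the seed entities and stops after `num_hops` levels. The pure-Python sketch below shows that idea on raw triples. It treats edges as undirected and silently skips seeds that are not in the graph; both are assumptions, not documented PyHealth behavior. It also returns plain Python objects rather than a PyG `Data` object. `k_hop_triples` and the extra `treated_by` edge are hypothetical.

```python
from collections import defaultdict


def k_hop_triples(triples, seed_entities, num_hops):
    """Return (entities, triples) within num_hops of the seeds, following
    edges in both directions. Hypothetical helper for illustration."""
    adjacency = defaultdict(set)
    for head, _, tail in triples:
        adjacency[head].add(tail)
        adjacency[tail].add(head)
    # Assumption: seeds absent from the graph are dropped.
    frontier = {e for e in seed_entities if e in adjacency}
    visited = set(frontier)
    for _ in range(num_hops):
        frontier = {n for e in frontier for n in adjacency[e]} - visited
        if not frontier:
            break
        visited |= frontier
    # Keep only edges whose endpoints both fall inside the visited set.
    kept = [t for t in triples if t[0] in visited and t[2] in visited]
    return visited, kept


triples = [
    ("aspirin",   "treats",     "headache"),
    ("headache",  "symptom_of", "migraine"),
    ("ibuprofen", "treats",     "headache"),
    ("migraine",  "treated_by", "sumatriptan"),  # hypothetical extra edge
]
nodes, edges = k_hop_triples(triples, ["aspirin"], num_hops=2)
print(sorted(nodes))  # ['aspirin', 'headache', 'ibuprofen', 'migraine']
print(len(edges))     # 3
```

PyTorch Geometric ships a tensor-based counterpart, `torch_geometric.utils.k_hop_subgraph`, that operates on an `edge_index`.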

Note

subgraph() requires PyTorch Geometric (torch_geometric). The graph can still be constructed and explored without it — only subgraph extraction needs PyG.

Install with: pip install torch-geometric
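Keeping PyG optional usually comes down to importing it only inside the code path that needs it. The sketch below shows that common pattern; it is a generic illustration, not PyHealth's source, and `pyg_available` and `require_pyg` are hypothetical names.

```python
import importlib
import importlib.util


def pyg_available():
    """True if a torch_geometric package can be found in this environment."""
    return importlib.util.find_spec("torch_geometric") is not None


def require_pyg():
    """Import torch_geometric lazily, with an actionable error if missing."""
    if not pyg_available():
        raise ImportError(
            "subgraph extraction requires PyTorch Geometric; "
            "install with: pip install torch-geometric"
        )
    return importlib.import_module("torch_geometric")


# Graph construction and lookups never call require_pyg(), so they work
# without PyG; only the subgraph step would.
if pyg_available():
    print("PyG found: subgraph() can be used")
else:
    print("PyG missing: build and explore the graph, skip subgraph()")
```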

Using with GraphProcessor in a Task

To feed subgraphs into a model automatically during data loading, pass a configured GraphProcessor instance in your task’s input_schema. The processor will call kg.subgraph() for each patient sample:

from pyhealth.graph import KnowledgeGraph
from pyhealth.processors import GraphProcessor
from pyhealth.tasks import BaseTask

kg = KnowledgeGraph(triples="medical_kg.tsv")

class MyGraphTask(BaseTask):
    task_name = "MyGraphTask"
    input_schema = {
        "conditions":  "sequence",
        "kg_subgraph": GraphProcessor(kg, num_hops=2),
    }
    output_schema = {"label": "binary"}

    def __call__(self, patient):
        ...
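The division of labor can be pictured with a stand-in: the graph is built once and handed to the processor, which then issues one subgraph call per sample. The sketch below imitates that flow without PyHealth or PyG. `ToyGraphProcessor`, its `process` method, `StubKG`, and the assumption that a sample's field value is its list of seed codes are all hypothetical; see the GraphProcessor API reference for the real interface.

```python
class ToyGraphProcessor:
    """Hypothetical stand-in for GraphProcessor: stores a graph and a hop
    count, then requests one subgraph per sample."""

    def __init__(self, kg, num_hops=2):
        self.kg = kg              # built once, shared across all samples
        self.num_hops = num_hops

    def process(self, codes):
        # Assumption: the sample's field value is the list of seed codes.
        return self.kg.subgraph(seed_entities=list(codes), num_hops=self.num_hops)


class StubKG:
    """Records subgraph() calls instead of extracting anything."""

    def __init__(self):
        self.calls = []

    def subgraph(self, seed_entities, num_hops):
        self.calls.append((tuple(seed_entities), num_hops))
        return {"seeds": seed_entities, "num_hops": num_hops}


kg = StubKG()
processor = ToyGraphProcessor(kg, num_hops=2)
samples = [["ICD9:250.00"], ["ICD9:250.00", "NDC:0069-0105"]]
graphs = [processor.process(codes) for codes in samples]
print(kg.calls)
# [(('ICD9:250.00',), 2), (('ICD9:250.00', 'NDC:0069-0105'), 2)]
```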

Pre-computed Node Embeddings

If you already have entity embeddings (e.g. from TransE or an LLM), you can attach them to the graph at construction time. The model can then use these as initial node features instead of learning them from scratch:

import torch

node_features = torch.randn(kg.num_entities, 64)  # random placeholder; use your real (num_entities, feat_dim) embeddings
kg = KnowledgeGraph(triples=triples, node_features=node_features)
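Embeddings from an external source are usually keyed by entity name, so they have to be laid out row by row before being passed in. The sketch below does that in plain Python, assuming (not confirmed here) that row `i` of `node_features` should describe the entity with ID `i` in `kg.entity2id`. Entities without a pretrained vector get a zero row, which is one choice among several. `align_features` is a hypothetical helper.

```python
def align_features(entity2id, pretrained, dim):
    """Build a num_entities x dim matrix (list of lists) whose row i is the
    pretrained vector for the entity with ID i, or zeros if none exists."""
    rows = [[0.0] * dim for _ in range(len(entity2id))]
    missing = []
    for entity, idx in entity2id.items():
        vector = pretrained.get(entity)
        if vector is None:
            missing.append(entity)  # row stays all zeros
            continue
        if len(vector) != dim:
            raise ValueError(f"{entity}: expected {dim} dims, got {len(vector)}")
        rows[idx] = list(vector)
    return rows, missing


entity2id = {"aspirin": 0, "headache": 1, "migraine": 2, "ibuprofen": 3}
pretrained = {"aspirin": [0.1, 0.2], "headache": [0.3, 0.4]}  # toy 2-d vectors
rows, missing = align_features(entity2id, pretrained, dim=2)
print(rows)     # [[0.1, 0.2], [0.3, 0.4], [0.0, 0.0], [0.0, 0.0]]
print(missing)  # ['migraine', 'ibuprofen']
```

With a real graph you would pass `kg.entity2id` and convert the result with `torch.tensor(rows)` before using it as `node_features`.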

API Reference