Source code for pyhealth.medcode.inner_map

import logging
import os
from abc import ABC, abstractmethod
from typing import List

import networkx as nx
import pandas as pd

import pyhealth.medcode as medcode
from pyhealth.medcode.utils import MODULE_CACHE_PATH, download_and_read_csv
from pyhealth.utils import load_pickle, save_pickle

logger = logging.getLogger(__name__)


# TODO: add this callable method: InnerMap(vocab)
class InnerMap(ABC):
    """Contains information for a specific medical code system.

    `InnerMap` is a base abstract class for all medical code systems.
    It will be instantiated as a specific medical code system with
    `InnerMap.load(vocabulary)`.

    The code system is stored in `self.graph`, a `networkx.DiGraph` with one
    node per code (node attributes are the CSV columns other than `code` and
    `parent_code`) and one edge from every non-missing `parent_code` to its
    child code.

    Note:
        This class cannot be instantiated using `__init__()` (throws an error).
    """

    @abstractmethod
    def __init__(
        self,
        vocabulary: str,
        refresh_cache: bool = False,
    ):
        """Loads the code graph from the pickle cache or rebuilds it from CSV.

        Args:
            vocabulary: vocabulary name; also used as the stem of the cached
                `.pkl` file and of the `.csv` file to read.
            refresh_cache: if True, ignore an existing pickle and rebuild the
                graph; the flag is also forwarded to `download_and_read_csv`.
                Default is False.
        """
        # abstractmethod prevents initialization of this class
        self.vocabulary = vocabulary

        pickle_filepath = os.path.join(MODULE_CACHE_PATH, self.vocabulary + ".pkl")
        csv_filename = self.vocabulary + ".csv"
        if os.path.exists(pickle_filepath) and (not refresh_cache):
            self.graph = load_pickle(pickle_filepath)
            # log only once the load has actually succeeded
            logger.debug(f"Loaded {vocabulary} code from {pickle_filepath}")
        else:
            logger.debug(f"Processing {vocabulary} code...")
            df = download_and_read_csv(csv_filename, refresh_cache)
            # create graph
            df = df.set_index("code")
            self.graph = nx.DiGraph()
            # add nodes (kept as a separate first pass so that nodes are
            # inserted in CSV row order, before any edge is added)
            for code, row in df.iterrows():
                row_dict = row.to_dict()
                row_dict.pop("parent_code", None)
                self.graph.add_node(code, **row_dict)
            # add edges; the column check is the same for every row, so it is
            # done once instead of per row
            if "parent_code" in df.columns:
                for code, row in df.iterrows():
                    if not pd.isna(row["parent_code"]):
                        self.graph.add_edge(row["parent_code"], code)
            save_pickle(self.graph, pickle_filepath)
            # log only once the save has actually succeeded
            logger.debug(f"Saved {vocabulary} code to {pickle_filepath}")
        return

    def __repr__(self):
        return f"InnerMap(vocabulary={self.vocabulary}, graph={self.graph})"

    @classmethod
    def load(cls, vocabulary: str, refresh_cache: bool = False):
        """Initializes a specific medical code system inheriting from `InnerMap`.

        The class is looked up by name as an attribute of `pyhealth.medcode`.

        Args:
            vocabulary: vocabulary name. E.g., "ICD9CM", "ICD9PROC".
            refresh_cache: whether to refresh the cache. Default is False.

        Returns:
            An instance of the class named `vocabulary`.

        Raises:
            AttributeError: if `pyhealth.medcode` has no attribute named
                `vocabulary`.

        Examples:
            >>> from pyhealth.medcode import InnerMap
            >>> icd9cm = InnerMap.load("ICD9CM")
            >>> icd9cm.lookup("428.0")
            'Congestive heart failure, unspecified'
            >>> icd9cm.get_ancestors("428.0")
            ['428', '420-429.99', '390-459.99', '001-999.99']
        """
        vocabulary_cls = getattr(medcode, vocabulary)
        return vocabulary_cls(refresh_cache=refresh_cache)

    @property
    def available_attributes(self) -> List[str]:
        """Returns a list of available attributes.

        The attribute names are read from the first node of the graph only.

        Returns:
            List of available attributes; an empty list if the graph has no
            nodes.
        """
        # Peek at the first node without materializing every node's attributes.
        first_node_attributes = next(iter(self.graph.nodes.values()), {})
        return list(first_node_attributes.keys())

    def stat(self):
        """Prints statistics of the code system."""
        print()
        print(f"Statistics for {self.vocabulary}:")
        print(f"\t- Number of nodes: {len(self.graph.nodes)}")
        print(f"\t- Number of edges: {len(self.graph.edges)}")
        print(f"\t- Available attributes: {self.available_attributes}")
        print()

    @staticmethod
    def standardize(code: str) -> str:
        """Standardizes a given code.

        Subclass will override this method based on different
        medical code systems. This base implementation returns the code
        unchanged.
        """
        return code

    @staticmethod
    def convert(code: str, **kwargs) -> str:
        """Converts a given code.

        Subclass will override this method based on different
        medical code systems. This base implementation ignores `kwargs` and
        returns the code unchanged.
        """
        return code

    def lookup(self, code: str, attribute: str = "name"):
        """Looks up the code.

        Args:
            code: code to look up.
            attribute: attribute to look up. One of `self.available_attributes`.
                Default is "name".

        Returns:
            The attribute value of the code.

        Raises:
            KeyError: if the standardized code is not a node of the graph or
                the node does not have the requested attribute.
        """
        code = self.standardize(code)
        return self.graph.nodes[code][attribute]

    def __contains__(self, code: str) -> bool:
        """Checks if the code is in the code system."""
        code = self.standardize(code)
        return code in self.graph.nodes

    def get_ancestors(self, code: str) -> List[str]:
        """Gets the ancestors of the code.

        Args:
            code: code to look up.

        Returns:
            List of ancestors ordered from the closest to the farthest.
            Ancestors at the same distance are ordered by code.
        """
        code = self.standardize(code)
        ancestors = nx.ancestors(self.graph, code)
        # One search towards `code` yields the distance from every ancestor,
        # instead of running a separate shortest-path search per ancestor.
        distance_to_code = nx.shortest_path_length(self.graph, target=code)
        return sorted(ancestors, key=lambda x: (distance_to_code[x], x))

    def get_descendants(self, code: str) -> List[str]:
        """Gets the descendants of the code.

        Args:
            code: code to look up.

        Returns:
            List of descendants ordered from the closest to the farthest.
            Descendants at the same distance are ordered by code.
        """
        code = self.standardize(code)
        descendants = nx.descendants(self.graph, code)
        # One search from `code` yields the distance to every descendant,
        # instead of running a separate shortest-path search per descendant.
        distance_from_code = nx.shortest_path_length(self.graph, source=code)
        return sorted(descendants, key=lambda x: (distance_from_code[x], x))
if __name__ == "__main__":
    # Small manual demo: load the ICD9CM code system and exercise the
    # membership, lookup and hierarchy queries.
    icd9cm = InnerMap.load("ICD9CM")
    # stat() prints its own report and has no return value, so wrapping it in
    # print() would only emit an extra "None" line.
    icd9cm.stat()
    print("428.0" in icd9cm)
    print(icd9cm.lookup("4280"))
    print(icd9cm.get_ancestors("428.0"))
    print(icd9cm.get_descendants("428.0"))