Stage-Aware Hierarchical Attentive Relational Network for Diagnosis Prediction
Application of Hierarchical Attentive Relational Network in Diagnostic Prediction
In recent years, Electronic Health Records (EHR) have become extremely valuable for improving medical decision-making and for online disease detection and monitoring. Deep learning methods have also achieved great success in using EHR for health risk prediction and diagnostic prediction. However, deep learning models usually require large amounts of data because they have so many parameters. EHR data also contain many rare medical codes, which poses a significant challenge for clinical applications. Some studies have therefore proposed using medical ontologies to improve predictive performance and to provide interpretable predictions. However, these ontologies are usually small in scale and coarse in granularity. They lack many diagnostic and medical concepts, let alone the various relationships between those concepts.
To overcome this limitation, the paper proposes incorporating existing large-scale medical knowledge graphs (KGs) into diagnostic prediction and designs a model named HAR (Hierarchical Attentive Relational Network). For each visit, HAR extracts a personalized sub-KG from an existing medical KG. It then performs relation-specific message passing and hierarchical message aggregation on this sub-KG to refine the representations of the nodes that correspond to the visit's medical codes. HAR also accounts for the patient's current stage of disease progression and uses it when computing relation-level and node-level attention. Extensive experiments on two public datasets show that HAR improves both visit-level and code-level accuracy in diagnostic prediction tasks.
Research Background
Electronic Health Records (EHR) are now a widely used medical information technology. EHR data are organized chronologically as a sequence of visits, and each visit contains multiple medical codes that record clinical diagnoses. Studies have shown that EHR data can improve the efficiency of medical services and can also support tasks such as medical concept extraction and disease prediction. Meanwhile, deep learning models have achieved great success in fields such as computer vision, natural language processing, graph learning, and data mining. Consequently, many deep learning-based methods have been proposed for modeling EHR data. Compared with traditional approaches, they require less preprocessing and feature engineering and can achieve better performance.
However, deep learning-based EHR models usually require large amounts of data because they have so many parameters, so their performance is often unsatisfactory when the training set is small. In addition, a significant proportion of medical codes in EHR data appear infrequently, which makes it difficult to learn accurate representations for these rare codes. In this context, researchers have proposed introducing external medical knowledge into deep learning models to improve their performance.
For example, GRAM (Graph-based Attention Model) introduced a medical ontology, the Clinical Classifications Software (CCS), into deep learning models through neural attention. However, medical ontologies offer limited benefits for two main reasons. First, most medical ontologies are relatively small. CCS, for example, contains only hundreds of concepts, so most diagnoses lack a corresponding concept. Second, an ontology is essentially a disease classification tree. It lacks the various relationships between different diseases and cannot reflect disease progression. The paper therefore proposes introducing the existing large-scale medical knowledge graph SemMed (Semantic MEDLINE) into diagnostic prediction.
Although some studies have also proposed using large-scale medical knowledge graphs, they do not consider the patient's specific disease stage. For example, a diagnosis of fever may indicate different degrees of severity at different stages. An experienced doctor would therefore treat a patient presenting with fever differently depending on the patient's medical history.
Research Sources
This paper was written by Liping Wang, Qiang Liu, Mengqi Zhang, Yaxuan Hu, Shu Wu, and Liang Wang and published in IEEE Transactions on Knowledge and Data Engineering in April 2024. Part of the research was supported by the National Natural Science Foundation.
Research Methods
The HAR model consists of four main parts: the stage-aware relation-level attention module, the stage-aware node-level attention module, the relation-specific message passing module, and the hierarchical message aggregation module. This model is designed as a general module that can be combined with various temporal prediction models.
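To make the plug-in design concrete, the loop below sketches how a refinement module of this kind could sit between code embeddings and a recurrent predictor. It is an illustrative assumption, not the paper's code. `embed`, `refine`, and `step` are hypothetical placeholders for the code embedding, HAR, and the temporal model P. The previous hidden state is assumed to act as the stage signal, and refined code vectors are assumed to be mean-pooled into a visit vector.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]

def run_patient(visits: Sequence[Sequence[int]],
                embed: Callable[[int], Vector],
                refine: Callable[[Vector, Vector], Vector],
                step: Callable[[Vector, Vector], Vector],
                hidden_dim: int) -> List[Vector]:
    """Feed a patient's visits through a refinement module and a temporal model.

    refine(code_vec, stage) stands in for HAR; step(h_prev, visit_vec) stands in
    for a recurrent predictor P. Both are hypothetical placeholders.
    """
    h = [0.0] * hidden_dim
    states = []
    for codes in visits:
        # Assumption: the previous hidden state serves as the stage signal.
        refined = [refine(embed(c), h) for c in codes]
        # Mean-pool the refined code vectors into one visit vector.
        visit_vec = [sum(col) / len(refined) for col in zip(*refined)]
        h = step(h, visit_vec)
        states.append(h)
    return states


# Toy stand-ins (hypothetical): a fixed embedding, an additive "refinement",
# and a tanh recurrence.
embed = lambda c: [float(c), 1.0]
refine = lambda x, s: [xi + si for xi, si in zip(x, s)]
step = lambda h, v: [math.tanh(hi + vi) for hi, vi in zip(h, v)]

states = run_patient([[1, 3], [2]], embed, refine, step, hidden_dim=2)
print(len(states))
```

Because `refine` only sees a code vector and a stage vector, any temporal model that exposes a hidden state could be swapped in for `step`. This matches the description of HAR as a general module.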
Personalized Graph Extraction
For each visit, a personalized sub-graph is extracted from the large-scale medical knowledge graph, and the model performs relation-specific message passing and hierarchical message aggregation on this sub-graph. The personalized sub-graph contains only knowledge relevant to the patient's current disease state, which prevents messages from being passed between nodes that are unrelated to the patient.
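The extraction step can be sketched as follows. This is a minimal illustration, not the paper's procedure, and it rests on three assumptions. The knowledge graph is a list of (head, relation, tail) triples. Each visit's codes have already been mapped to KG entities. The sub-graph keeps every triple whose endpoints both lie within `hops` steps of those entities. The function `extract_subgraph` and the toy triples are hypothetical.

```python
from collections import defaultdict

def extract_subgraph(triples, seed_entities, hops=1):
    """Return the triples whose endpoints both lie within `hops` steps of the seeds.

    Hypothetical extraction rule used for illustration; the paper's exact
    criterion may differ.
    """
    # Undirected adjacency so neighbors are reachable in either direction.
    adj = defaultdict(set)
    for h, _, t in triples:
        adj[h].add(t)
        adj[t].add(h)

    # Breadth-first expansion from the entities mentioned in the visit.
    kept = set(seed_entities)
    frontier = set(seed_entities)
    for _ in range(hops):
        frontier = {n for e in frontier for n in adj[e]} - kept
        kept |= frontier

    # Keep only edges between retained nodes, so no messages flow to or from
    # entities unrelated to this patient.
    return [(h, r, t) for h, r, t in triples if h in kept and t in kept]


kg = [  # toy triples with made-up relation names
    ("fever", "symptom_of", "pneumonia"),
    ("pneumonia", "causes", "sepsis"),
    ("sepsis", "treated_by", "antibiotics"),
    ("fracture", "treated_by", "cast"),
]
print(extract_subgraph(kg, {"fever"}, hops=1))
# [('fever', 'symptom_of', 'pneumonia')]
```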
Medical Code Embedding
Converting discrete medical codes into meaningful, learnable representations is crucial. The paper stores one trainable embedding vector per medical code in a parameterized embedding matrix and learns this matrix end to end together with the rest of the model.
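As a minimal sketch of such an embedding layer, assume a vocabulary of `num_codes` codes and an embedding dimension `dim` (both hypothetical values here). The matrix is a table of trainable rows, and encoding a visit amounts to a row lookup. In a real model the rows would be updated by back-propagation with the rest of the network. This sketch only shows initialization and lookup, and the class name `CodeEmbedding` is hypothetical.

```python
import random

class CodeEmbedding:
    """Lookup table in which row i is the vector for medical code i."""

    def __init__(self, num_codes, dim, seed=0):
        rng = random.Random(seed)
        # Small random initialization; training would learn these values end to end.
        self.weight = [[rng.uniform(-0.1, 0.1) for _ in range(dim)]
                       for _ in range(num_codes)]

    def __call__(self, code_ids):
        # Encoding a visit means selecting the rows of its codes.
        return [self.weight[i] for i in code_ids]


code_to_id = {"428.0": 0, "401.9": 1, "250.00": 2}  # hypothetical ICD-9 vocabulary
emb = CodeEmbedding(num_codes=len(code_to_id), dim=4)
visit = ["428.0", "250.00"]
vectors = emb([code_to_id[c] for c in visit])
print(len(vectors), len(vectors[0]))  # 2 4
```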
Stage-Aware Relation-Level and Node-Level Attention
This mechanism accounts for the patient's current stage of disease progression when it assigns weights to each relation type and each neighboring node. HAR combines this stage information with the hidden vectors of the downstream prediction model, which summarize the patient's visits so far. The relation-level and node-level attention weights therefore change with the patient's history, and the model can prioritize relations and neighbors more discriminatively.
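The effect of conditioning on the stage can be shown with a deliberately simplified scorer. In this sketch, which is an assumption and not the paper's formulation, each relation (or each neighbor) is scored by its dot product with a stage vector, and the scores are normalized with a softmax. The paper's learned projections are omitted, and the toy vectors and the name `stage_aware_weights` are hypothetical.

```python
import math

def softmax(scores):
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(x - m) for x in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def stage_aware_weights(candidate_vecs, stage):
    """Softmax over candidates (relation types, or neighbors within one relation),
    each scored by its dot product with the stage vector."""
    return softmax([dot(c, stage) for c in candidate_vecs])


relations = [[2.0, 0.0], [0.0, 2.0]]  # hypothetical embeddings of two relation types
stage_early = [1.0, 0.0]              # hypothetical hidden state early in the disease course
stage_late = [0.0, 1.0]               # hypothetical hidden state at a later stage

w_early = stage_aware_weights(relations, stage_early)
w_late = stage_aware_weights(relations, stage_late)
print([round(w, 3) for w in w_early])  # [0.881, 0.119]
print([round(w, 3) for w in w_late])   # [0.119, 0.881]
```

The same two relations are ranked in opposite orders under the two stage vectors. This is the behavior that stage-aware attention is meant to capture.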
Relation-Specific Message Passing and Hierarchical Message Aggregation
Within the personalized sub-graph, HAR passes messages from source nodes to target nodes separately for each relation type. This preserves the distinct information carried by different relations and lets each relation's messages be weighted by their importance. Aggregation is hierarchical. First, messages are passed and combined among neighbors connected by the same relation type. Next, the per-relation results are aggregated across relation types. Finally, the resulting node representations are sent to the downstream prediction model.
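The following is a minimal sketch of the two-level aggregation for a single target node, under several assumptions. Each relation type has its own transformation matrix. Node-level and relation-level weights are softmax-normalized dot products with a stage vector. The aggregated message is added to the node's own embedding as a residual update. The function `aggregate`, the relation names, and this particular update rule are hypothetical, since this summary does not give HAR's exact equations.

```python
import math

def softmax(xs):
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    s = sum(e)
    return [v / s for v in e]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def matvec(W, x):
    return [dot(row, x) for row in W]

def add(a, b):
    return [x + y for x, y in zip(a, b)]

def scale(c, v):
    return [c * x for x in v]

def aggregate(target, neighbors_by_rel, W_by_rel, stage):
    """Hierarchical aggregation for one target node.

    neighbors_by_rel: {relation: [neighbor vectors]}
    W_by_rel:         {relation: square matrix that transforms that relation's messages}
    stage:            stage vector (e.g. the downstream model's hidden state)
    """
    rel_names, rel_msgs = [], []
    # Level 1: within each relation, transform messages with the relation's own
    # matrix and combine them using node-level attention.
    for rel, neigh in neighbors_by_rel.items():
        msgs = [matvec(W_by_rel[rel], x) for x in neigh]
        beta = softmax([dot(m, stage) for m in msgs])
        combined = [0.0] * len(target)
        for b, m in zip(beta, msgs):
            combined = add(combined, scale(b, m))
        rel_names.append(rel)
        rel_msgs.append(combined)
    # Level 2: combine the per-relation summaries using relation-level attention.
    alpha = softmax([dot(m, stage) for m in rel_msgs])
    z = [0.0] * len(target)
    for a, m in zip(alpha, rel_msgs):
        z = add(z, scale(a, m))
    # Residual update (an assumption): keep the node's own embedding.
    return add(target, z), dict(zip(rel_names, alpha))


identity = [[1.0, 0.0], [0.0, 1.0]]
target = [1.0, 1.0]
neighbors = {"causes": [[1.0, 0.0]], "treated_by": [[0.0, 1.0], [0.0, 3.0]]}
weights = {"causes": identity, "treated_by": identity}

out, alpha = aggregate(target, neighbors, weights, stage=[0.0, 0.0])
print(out)    # [1.5, 2.0]  (a zero stage vector gives uniform attention)
print(alpha)  # {'causes': 0.5, 'treated_by': 0.5}
```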
End-to-End Training and Integration with Existing Prediction Models
HAR is trained jointly with an existing prediction model P on a multi-label classification objective using a cross-entropy loss. Gradient descent optimizes the parameters of HAR and P together.
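For multi-label targets, cross-entropy is usually applied per code with a sigmoid output, which gives binary cross-entropy averaged over all candidate codes. The sketch below assumes that reading, since the summary does not state the exact form. It also shows the gradient with respect to each logit. In joint training, this gradient would be back-propagated through P and then into HAR. The function names are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def multilabel_bce(logits, targets, eps=1e-12):
    """Binary cross-entropy averaged over all candidate codes.

    logits:  one real-valued score per code for the next visit
    targets: multi-hot vector, 1 if the code appears in the next visit
    """
    total = 0.0
    for z, y in zip(logits, targets):
        p = sigmoid(z)
        total -= y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps)
    return total / len(logits)

def bce_grad_wrt_logits(logits, targets):
    """d(loss)/d(logit_i) = (sigmoid(z_i) - y_i) / N."""
    n = len(logits)
    return [(sigmoid(z) - y) / n for z, y in zip(logits, targets)]


logits = [2.0, -1.0, 0.0]
targets = [1, 0, 1]
loss = multilabel_bce(logits, targets)
grad = bce_grad_wrt_logits(logits, targets)
print(round(loss, 4))
```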
Experimental Results
The paper validates HAR through extensive experiments on two public datasets, MIMIC-III and MIMIC-IV. The results show that HAR improves predictive performance in terms of both visit-level and code-level accuracy. Ablation studies confirm the soundness of the model architecture and the contribution of each component to HAR's overall performance. Case studies show that the attention coefficients generated by HAR can give doctors explicit explanations of the diagnostic predictions.
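This summary does not define the two metrics. The definitions below are common in diagnosis prediction work and are assumed here rather than taken from the paper. Visit-level precision@k is the number of hits among the top k predicted codes divided by min(k, |y|), averaged over visits. Code-level accuracy@k is the total number of hits divided by the total number of ground-truth codes across all visits. The function names and toy scores are hypothetical.

```python
def top_k(scores, k):
    """Indices of the k highest-scoring codes."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

def visit_level_precision_at_k(all_scores, all_truths, k):
    """Average over visits of hits / min(k, |truth|); assumes each visit has a true code."""
    vals = []
    for scores, truth in zip(all_scores, all_truths):
        hits = len(set(top_k(scores, k)) & truth)
        vals.append(hits / min(k, len(truth)))
    return sum(vals) / len(vals)

def code_level_accuracy_at_k(all_scores, all_truths, k):
    """Total hits in the top k divided by the total number of true codes."""
    hits = sum(len(set(top_k(s, k)) & t) for s, t in zip(all_scores, all_truths))
    total = sum(len(t) for t in all_truths)
    return hits / total


scores = [[0.9, 0.1, 0.8, 0.3], [0.2, 0.7, 0.6, 0.1]]  # predicted scores per visit
truths = [{0, 3}, {1}]                                  # true code indices per visit
print(visit_level_precision_at_k(scores, truths, k=2))  # 0.75
print(code_level_accuracy_at_k(scores, truths, k=2))
```

The two metrics differ in how they treat visits with few true codes. The visit-level version caps the denominator at k, and the code-level version pools all codes before dividing.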
Experimental Setup and Datasets
The experiments use two publicly available EHR datasets, MIMIC-III and MIMIC-IV, which contain health records of ICU patients. The experiments focus on the diagnosis codes in these datasets and use them for the prediction tasks. They also use the large-scale medical knowledge graph SemMed, which contains more than 150,000 entities and 64 relation types.
Performance Comparison
The paper compares HAR with baseline models such as LSTM, RETAIN, Dipole, RAIM, StageNet, and HiTANet. Models combined with HAR perform better at both the visit level and the code level, and the improvements are largest for rare diseases and small datasets.
Ablation Study
The ablation studies confirm that the stage-aware relation-level and node-level attention mechanisms are necessary for the performance gains, and they assess the impact of the different attention mechanisms.
Sensitivity Analysis
A sensitivity analysis of the hyperparameter λ shows that HAR's performance remains relatively stable for λ in the range [0, 0.9], which reflects the effectiveness of the external medical knowledge in the model.
Model Interpretability
Case studies show that analyzing the attention coefficients generated by HAR can reveal relationships among different disease symptoms, providing explicit explanations for clinicians.
Conclusion
The proposed HAR model addresses two challenges faced by existing diagnostic prediction models, limited data and rare medical codes. It does so by integrating a large-scale medical knowledge graph into diagnostic prediction, and it achieves significant performance improvements. Case studies further show that HAR offers good interpretability, which could provide strong support for clinical practice.