Federated Learning Using Model Projection for Multi-Center Disease Diagnosis with Non-IID Data
Background Introduction
With the rapid development of medical imaging technology, automated diagnostic methods have achieved good performance on single-center datasets. In practice, however, these methods often generalize poorly to data from other healthcare facilities. The primary reason is that they typically assume the data from different medical centers are independent and identically distributed (IID), whereas in reality the data are not (Non-IID), because different centers use different scanners and imaging parameters. The number and types of patients diagnosed also differ significantly between centers. Multi-center data is therefore heterogeneous and cannot be effectively addressed by centralized learning.
In recent years, federated learning (FL) has emerged as a decentralized framework that allows multiple centers to collaboratively train a global model while preserving the privacy of patient data at each center. However, applying federated learning to Non-IID data still faces two key problems: catastrophic forgetting on the client side and invalid aggregation on the server side.
Paper Source
The paper is authored by Jie Du, Wei Li, Peng Liu, Chi-Man Vong, Yongke You, Baiying Lei, and Tianfu Wang from Shenzhen University and the University of Macau. It is to appear in a forthcoming issue of the journal Neural Networks; the manuscript was accepted on May 23, 2024. The recommended citation is:
Du, J., Li, W., Liu, P., Vong, C. M., You, Y., Lei, B., & Wang, T. (2024). Federated learning using model projection for multi-center disease diagnosis with non-iid data. Neural Networks. doi: https://doi.org/10.1016/j.neunet.2024.106409
Research Workflow
Method Overview
This study proposes a federated learning method using model projection (FedMOP) that addresses catastrophic forgetting on the client side and invalid aggregation on the server side. The core idea of FedMOP is to use model projection to achieve two goals:
1. Avoid increasing the loss on the global data after local training on the client side (to prevent performance degradation).
2. Avoid increasing the global model's loss on the local data (to improve convergence speed).
Model Projection on Client Side (MPC)
During each communication round, each client first performs multiple rounds of local training and then executes MPC to mitigate catastrophic forgetting. Using a local linear approximation of the loss together with an optimization constraint, MPC ensures that the client model’s loss on the global data does not increase, so the knowledge received from the server is retained.
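The idea of a linearized "loss must not increase" constraint can be illustrated with a small sketch. It is not the paper's exact formulation: it assumes a single constraint and a flat parameter vector, and the gradient of the global-data loss (`g_global`) is treated as a given, hypothetical input, since how it is obtained is not described here. Under a first-order approximation, the loss change caused by an update `delta` is roughly `g . delta`, so requiring `g . delta <= 0` and taking the closest such update gives a simple half-space projection. The function name `project_update` is made up for illustration.

```python
import numpy as np

def project_update(update, protected_grad):
    """Return the vector closest (in L2 norm) to `update` whose dot product
    with `protected_grad` is <= 0, i.e. the first-order estimate of the
    protected loss does not increase."""
    dot = float(np.dot(protected_grad, update))
    if dot <= 0.0:
        # The constraint already holds; keep the local update unchanged.
        return update.copy()
    # dot > 0 implies protected_grad is non-zero, so the division is safe.
    scale = dot / float(np.dot(protected_grad, protected_grad))
    return update - scale * protected_grad

# Toy example with hypothetical numbers.
w_global = np.array([1.0, 1.0])   # model received from the server
w_local = np.array([2.0, 0.0])    # model after local training
g_global = np.array([1.0, 0.0])   # assumed gradient of the global-data loss at w_global

delta = w_local - w_global                 # raw local update
projected = project_update(delta, g_global)
w_client = w_global + projected            # model the client would upload
```

In this toy case the component of the update that would raise the linearized global-data loss is removed, while the component orthogonal to `g_global` is kept.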
Model Projection on Server Side (MPS)
Upon receiving the model parameters uploaded by the clients, the server uses the MPS method for aggregation to reduce the problem of invalid aggregation. Similarly, through optimization constraints, it ensures that the global model’s loss on the local data does not increase, thereby accelerating convergence.
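A server-side analogue can be sketched under stronger assumptions. The code below first forms a size-weighted average of the client models (as in FedAvg), then adjusts the resulting global update so that, to first order, it does not increase any client's local loss. The per-client gradients (`client_grads`) are hypothetical inputs, and cyclic projection onto half-spaces is used as a simple stand-in: it finds a feasible update but does not solve the constrained optimization problem the paper refers to. All function names are illustrative.

```python
import numpy as np

def weighted_average(client_weights, client_sizes):
    """Size-weighted average of client parameter vectors (FedAvg-style)."""
    sizes = np.asarray(client_sizes, dtype=float)
    stacked = np.stack(client_weights)
    return (sizes[:, None] * stacked).sum(axis=0) / sizes.sum()

def project_onto_halfspaces(update, grads, sweeps=100, tol=1e-10):
    """Cyclically project `update` onto each half-space {d : g . d <= 0}.
    Stops early once a full sweep leaves the update unchanged."""
    d = np.asarray(update, dtype=float).copy()
    for _ in range(sweeps):
        changed = False
        for g in grads:
            dot = float(np.dot(g, d))
            if dot > tol:  # constraint violated; dot > 0 implies g is non-zero
                d = d - (dot / float(np.dot(g, g))) * g
                changed = True
        if not changed:
            break
    return d

# Toy example with hypothetical numbers.
w_global = np.array([0.0, 0.0])
client_models = [np.array([2.0, -1.0]), np.array([0.0, -3.0])]
client_sizes = [1, 1]
client_grads = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]  # assumed local-loss gradients

avg_update = weighted_average(client_models, client_sizes) - w_global
safe_update = project_onto_halfspaces(avg_update, client_grads)
new_global = w_global + safe_update
```

With several non-orthogonal constraints, cyclic projection may need many sweeps and returns a feasible point rather than the nearest one; a quadratic-programming solver would be the more faithful choice for an exact projection.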
Experimental Design
Three real medical imaging datasets were selected for the study, representing two kinds of Non-IID problem: feature distribution skew and label distribution skew. FedMOP was compared against five popular federated learning methods: FedAvg, FedProx, Scaffold, FedAGrac, and FedReg.
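To make label distribution skew concrete, the following sketch shows two common ways of simulating it from a labeled dataset. Neither is claimed to be the paper's partitioning protocol: the Dirichlet split is a widely used convention in the federated learning literature, and the function names, the `alpha` parameter, and the toy labels are all assumptions for illustration.

```python
import numpy as np

def single_class_split(labels):
    """Give each client the sample indices of exactly one class
    (one client per class)."""
    labels = np.asarray(labels)
    return {int(c): np.flatnonzero(labels == c).tolist() for c in np.unique(labels)}

def dirichlet_split(labels, num_clients, alpha, seed=0):
    """Split sample indices so each class is spread over the clients with
    Dirichlet(alpha) proportions; smaller alpha gives a more uneven split."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == c))
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Convert cumulative proportions into cut points within this class.
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return client_indices

# Toy labels (hypothetical): six samples, three classes.
toy_labels = [0, 1, 0, 2, 1, 2]
by_class = single_class_split(toy_labels)
uneven = dirichlet_split(toy_labels, num_clients=2, alpha=0.5)
```

Every sample is assigned to exactly one client in both splits; only the per-client class proportions differ.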
Main Results
Feature Distribution Skew Scenario
In both two-client and four-client scenarios, FedMOP significantly outperformed other federated learning methods. Particularly in the four-client scenario, FedMOP notably improved accuracy, with an increase of 3.73% over traditional methods.
Label Distribution Skew Scenario
FedMOP also performed well in the uneven and single-class scenarios of the Covid-19 and PBC datasets. In the uneven scenarios, its accuracy increased by at least 1.70%; in the single-class scenarios, most notably on the PBC dataset, its accuracy increased by at least 2.47%.
Convergence Speed and Communication Cost
FedMOP converges quickly and also reduces communication costs. In most experiments, it achieved the highest accuracy with the lowest communication cost, demonstrating its effectiveness and efficiency on Non-IID problems.
Ablation Studies
The study verified the effectiveness of MPC and MPS through ablation experiments. Removing MPC alone caused a performance drop, and removing MPS had an even larger impact, indicating that both model projection steps play important roles in improving model performance.
Conclusion and Value
FedMOP performs well on Non-IID problems in multi-center medical imaging data. It effectively retains the knowledge obtained from client-side training, improves the adaptability of the global model to local data, significantly reduces communication costs, and speeds up convergence. Most importantly, while protecting data privacy, FedMOP achieves an accuracy comparable to or even higher than centralized learning. The study offers a new perspective and method for applying federated learning to medical data, improving model performance without giving up data privacy.
Future Research Directions
In the future, the research team plans to apply FedMOP to medical image segmentation tasks to further verify its effectiveness in more complex medical applications. By addressing Non-IID data challenges through model projection, FedMOP offers a practical and efficient solution for applying federated learning in healthcare.