AugDiff: Diffusion-Based Feature Augmentation for Multiple Instance Learning in Whole Slide Image
Diffusion-Based Feature Augmentation: A Novel Approach for Multiple Instance Learning in Whole Slide Images
Academic Background and Research Motivation
In computational pathology, effective analysis of Whole Slide Images (WSIs) is a fast-growing area of research. WSIs are ultra-high-resolution images with a broad field of view and are widely used in cancer diagnosis. However, labeled samples are scarce and each image is extremely large, so applying Multiple Instance Learning (MIL) to automated WSI analysis raises several challenges.
MIL is a classic weakly supervised learning approach that treats an entire WSI as a “bag” and its small image patches as “instances.” Labels are known at the bag level but not at the instance level. Two challenges stand out in MIL applications: overfitting caused by limited training data, and computational overhead caused by the large number of instances per bag. Data augmentation is a natural remedy for the first of these. However, traditional image augmentation methods such as rotation or stretching, although they improve model generalization, are inefficient when applied to thousands of patches per slide and can introduce redundant information that may hurt downstream tasks.
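The bag/instance setup can be made concrete with a small data structure. This is a minimal sketch. The `Bag` class and the helper functions are hypothetical names chosen for illustration. The "a bag is positive if any instance is positive" rule is the standard binary MIL assumption, not something specified by AugDiff.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Bag:
    """One WSI viewed as a MIL bag (hypothetical container, for illustration)."""
    instances: List[List[float]]                 # one feature vector per patch
    label: int                                   # slide-level label: known
    instance_labels: Optional[List[int]] = None  # patch-level labels: unknown in practice


def bag_label_from_instances(instance_labels: List[int]) -> int:
    """Standard binary MIL assumption: a bag is positive iff any instance is positive."""
    return int(any(label == 1 for label in instance_labels))


def max_pool_bag_score(instance_scores: List[float]) -> float:
    """Simplest MIL aggregator: the bag score is the highest instance score."""
    if not instance_scores:
        raise ValueError("a bag must contain at least one instance")
    return max(instance_scores)


slide = Bag(instances=[[0.2, 0.1], [0.9, 0.4], [0.3, 0.0]], label=1)
print(len(slide.instances), slide.label)    # 3 1
print(max_pool_bag_score([0.1, 0.8, 0.3]))  # 0.8
```

During training only `label` is available. A MIL model must learn which instances drive it.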
Recently, feature-level augmentation frameworks have gained prominence. Because they operate directly in feature space, they skip repetitive image processing steps and are therefore more efficient than image augmentation. However, existing feature augmentation methods often lack either diversity or stability in the features they generate. Examples include Mixup-based linear feature blending and Generative Adversarial Network (GAN)-based feature generation. This limits the overall quality and effectiveness of the augmentation. Diffusion Models (DMs), an emerging class of generative models, have shown strong diversity and stability, which makes them a promising candidate. This study is the first to introduce DMs into the MIL framework. It proposes a feature augmentation framework called AugDiff, which aims to provide high-quality online feature augmentation and address the shortcomings of existing methods.
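To see why linear blending limits diversity, consider generic Mixup applied to two feature vectors and their one-hot labels. This is a minimal sketch of textbook Mixup, not the specific Mixup baseline evaluated in the paper. The function name and the `alpha` default are assumptions.

```python
import random
from typing import List, Optional, Tuple


def mixup_features(
    f_a: List[float],
    f_b: List[float],
    y_a: List[float],
    y_b: List[float],
    alpha: float = 1.0,
    rng: Optional[random.Random] = None,
) -> Tuple[List[float], List[float], float]:
    """Mixup in feature space: a convex combination of two features and their one-hot labels."""
    rng = rng or random.Random()
    lam = rng.betavariate(alpha, alpha)  # mixing weight, lam ~ Beta(alpha, alpha)
    mixed_f = [lam * a + (1.0 - lam) * b for a, b in zip(f_a, f_b)]
    mixed_y = [lam * a + (1.0 - lam) * b for a, b in zip(y_a, y_b)]
    return mixed_f, mixed_y, lam


f, y, lam = mixup_features([0.0, 0.0], [2.0, 4.0], [1.0, 0.0], [0.0, 1.0], rng=random.Random(0))
# Every mixed feature lies on the segment between the two inputs, so here f[1] == 2 * f[0].
```

However `lam` is drawn, the output never leaves the line segment joining the two inputs. That is the "limited diversity" the paragraph refers to.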
Paper Source and Authors
The paper is authored by Zhuchen Shao, Liuxi Dai, Yifeng Wang, Haoqian Wang, and Yongbing Zhang, who are primarily affiliated with Tsinghua Shenzhen International Graduate School and Harbin Institute of Technology (Shenzhen). Titled “AugDiff: Diffusion-Based Feature Augmentation for Multiple Instance Learning in Whole Slide Images,” it was published in IEEE Transactions on Artificial Intelligence (Vol. 5, No. 12, December 2024). The work was funded by China’s National Key Research and Development Program and by Shenzhen Science and Technology projects. The code is open-sourced at https://github.com/szc19990412/augdiff.
Research Methodology and Workflow
Overall Framework:
The core idea of AugDiff is to exploit the step-by-step generative process of DMs to augment features while preserving their original semantic information. The AugDiff pipeline consists of the following steps:
1. WSI Splitting and Feature Extraction: WSIs are divided into multiple patches, and patch features are extracted using pretrained feature extractors (e.g., ResNet18 and RegNetX).
2. Training the Diffusion Model: The diffusion model is trained on features extracted from patches that were augmented with various image augmentation techniques (e.g., random rotation, color jittering), so that it learns to generate augmented features.
3. Integration with MIL Training: During MIL training, AugDiff generates augmented features on the fly, providing online augmentation for the MIL model.
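Step 3 can be pictured as a generator that feeds bags to the MIL model each epoch and swaps in freshly augmented features for some of them. This is a minimal sketch under stated assumptions. `online_augmented_bags`, `augment_fn`, and the probability `p_aug` are hypothetical. Whether AugDiff augments every bag or a random subset per epoch is not specified here. `add_one` merely stands in for the trained diffusion model.

```python
import random
from typing import Callable, Iterator, List, Optional, Sequence

Feature = List[float]
BagFeatures = List[Feature]


def online_augmented_bags(
    bags: Sequence[BagFeatures],
    augment_fn: Callable[[BagFeatures], BagFeatures],
    p_aug: float = 0.5,
    rng: Optional[random.Random] = None,
) -> Iterator[BagFeatures]:
    """Yield each bag for one MIL epoch, replacing it with augmented features with probability p_aug.

    Augmentation happens in feature space at training time, so no patch image is re-processed.
    """
    rng = rng or random.Random()
    for bag in bags:
        if rng.random() < p_aug:
            yield augment_fn(bag)  # e.g. a trained diffusion model applied to every patch feature
        else:
            yield bag


def add_one(bag: BagFeatures) -> BagFeatures:
    """Hypothetical stand-in augmenter; AugDiff would apply its diffusion model here."""
    return [[v + 1.0 for v in feat] for feat in bag]


bags = [[[0.0, 0.0]], [[1.0, 2.0], [3.0, 4.0]]]
for epoch in range(2):
    for bag in online_augmented_bags(bags, add_one, p_aug=0.5, rng=random.Random(epoch)):
        pass  # a MIL model's forward/backward pass on `bag` would go here
```

New augmentations are drawn every epoch, so the MIL model sees different feature variants without a precomputed augmented dataset.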
Design of the Diffusion Model:
A DM consists of two processes: forward diffusion and reverse diffusion. Forward diffusion gradually adds noise to the data until its distribution becomes approximately Gaussian. Reverse diffusion uses a denoising autoencoder (DAE) to remove the noise step by step and recover the data distribution. The key innovation in AugDiff is that sampling does not start from pure Gaussian noise. Instead, it starts from the original features, which lets AugDiff control how much semantic information is retained and how much is augmented during each sampling iteration.
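The forward process has a closed form that jumps from a clean feature straight to any noise level. This sketch uses the standard DDPM parameterization with a linear beta schedule (T = 1000, betas from 1e-4 to 0.02). These are assumptions, not AugDiff's confirmed settings. The helper names are hypothetical.

```python
import math
import random
from typing import List, Tuple


def linear_beta_schedule(T: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> List[float]:
    """Noise variances beta_1..beta_T, increasing linearly (the original DDPM default)."""
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]


def cumulative_alphas(betas: List[float]) -> List[float]:
    """alpha_bar_t = prod_{s<=t} (1 - beta_s): the fraction of signal variance left at step t."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def q_sample(x0: List[float], t: int, alpha_bar: List[float], rng: random.Random) -> Tuple[List[float], List[float]]:
    """Jump straight to step t: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, eps ~ N(0, I)."""
    a = alpha_bar[t]
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    xt = [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, eps)]
    return xt, eps


betas = linear_beta_schedule()
alpha_bar = cumulative_alphas(betas)
feature = [0.5, -1.2, 2.0]  # stand-in for one patch feature vector
early, _ = q_sample(feature, 10, alpha_bar, random.Random(0))   # still close to `feature`
late, _ = q_sample(feature, 999, alpha_bar, random.Random(0))   # essentially pure noise
```

Indices are 0-based, so `t = 999` is the last of the 1000 steps.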
Key Algorithms and Implementations:
The sampling process in AugDiff employs a two-stage approach:
- K-Step Diffusion: Gradually adds noise to the original features.
- K-Step Denoising: Uses the DAE to iteratively denoise features and generate augmented data.
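The two-stage sampler can be sketched as follows. First jump the original feature K steps into the forward process. Then apply K reverse steps. This minimal sketch uses the standard DDPM ancestral update with sigma_t^2 = beta_t, which is an assumption; the paper's exact sampler may differ. `eps_model` is a hypothetical stand-in for the trained DAE. A larger K pushes the result further from the input: more diversity, less retained semantics.

```python
import math
import random
from typing import Callable, List

Vector = List[float]


def diffuse_then_denoise(
    x0: Vector,
    K: int,
    betas: List[float],
    eps_model: Callable[[Vector, int], Vector],
    rng: random.Random,
) -> Vector:
    """Noise x0 forward K steps in one jump, then run K DDPM reverse steps back to step 0."""
    if not 1 <= K <= len(betas):
        raise ValueError("K must be between 1 and the number of diffusion steps")
    alpha_bar, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        alpha_bar.append(prod)

    # K-step diffusion: sample x_K ~ q(x_K | x_0) (0-based index K - 1).
    a = alpha_bar[K - 1]
    x = [math.sqrt(a) * v + math.sqrt(1.0 - a) * rng.gauss(0.0, 1.0) for v in x0]

    # K-step denoising: mean = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(1 - beta_t).
    for t in range(K - 1, -1, -1):
        beta = betas[t]
        eps = eps_model(x, t)
        coef = beta / math.sqrt(1.0 - alpha_bar[t])
        mean = [(xi - coef * ei) / math.sqrt(1.0 - beta) for xi, ei in zip(x, eps)]
        if t > 0:
            x = [m + math.sqrt(beta) * rng.gauss(0.0, 1.0) for m in mean]
        else:
            x = mean  # no noise is added on the final step
    return x


def untrained(x: Vector, t: int) -> Vector:
    """Hypothetical placeholder for the trained DAE: predicts zero noise."""
    return [0.0] * len(x)


betas = [1e-4 + (0.02 - 1e-4) * i / 999 for i in range(1000)]
augmented = diffuse_then_denoise([0.5, -1.2, 2.0], K=50, betas=betas, eps_model=untrained, rng=random.Random(0))
```

If the noise predictor were perfect, a single diffuse-then-denoise step (K = 1) would return the original feature exactly. The diversity comes from the model's imperfect, learned notion of what plausible augmented features look like.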
In addition, AugDiff uses six image augmentation methods (e.g., elastic deformation, Gaussian blur) to produce the augmented-patch features on which the diffusion model is trained. Core algorithmic details are given in the paper’s appendix. The training objective is:
$$ \mathcal{L}_{DM} = \mathbb{E}_{x,\, \epsilon \sim \mathcal{N}(0, \mathbf{I}),\, t}\left[ \lVert \epsilon - \epsilon_{\theta}(x_t, t) \rVert^2 \right] $$
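The expectation in this objective is usually estimated by sampling: for each training feature, draw a random step t and a noise vector ε, form the noisy feature x_t, and penalize the squared error of the predicted noise. This sketch assumes the standard DDPM forward process x_t = sqrt(ᾱ_t)·x + sqrt(1 − ᾱ_t)·ε. In AugDiff, x would be a feature from an augmented patch. `diffusion_loss` and `eps_model` are hypothetical names.

```python
import math
import random
from typing import Callable, List

Vector = List[float]


def diffusion_loss(
    batch: List[Vector],
    betas: List[float],
    eps_model: Callable[[Vector, int], Vector],
    rng: random.Random,
) -> float:
    """One Monte Carlo estimate of L_DM: draw t and eps per sample, noise x, score the noise prediction."""
    alpha_bar, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        alpha_bar.append(prod)
    total = 0.0
    for x in batch:
        t = rng.randrange(len(betas))          # t ~ Uniform over all steps (0-based here)
        eps = [rng.gauss(0.0, 1.0) for _ in x]  # eps ~ N(0, I)
        a = alpha_bar[t]
        x_t = [math.sqrt(a) * xi + math.sqrt(1.0 - a) * ei for xi, ei in zip(x, eps)]
        pred = eps_model(x_t, t)
        total += sum((ei - pi) ** 2 for ei, pi in zip(eps, pred))  # ||eps - eps_theta(x_t, t)||^2
    return total / len(batch)


def zero_model(x: Vector, t: int) -> Vector:
    """Hypothetical untrained predictor used only for illustration."""
    return [0.0] * len(x)


betas = [1e-4 + (0.02 - 1e-4) * i / 999 for i in range(1000)]
loss = diffusion_loss([[0.5, -1.2, 2.0], [0.1, 0.0, -0.3]], betas, zero_model, random.Random(0))
```

In a real training loop, this scalar would be minimized with respect to the DAE's parameters θ by gradient descent.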
Experimental Setup and Datasets:
The study evaluates AugDiff on four cancer datasets: Prostate Cancer (SICAPv2), Colorectal Cancer (Unitopatho), Pancreatic Cancer (TMAS), and Breast Cancer (Camelyon16). Two feature extractors (ResNet18, RegNetX) and three classic MIL methods (AMIL, LossAttn, DSMIL) were used to validate AugDiff’s applicability and robustness.
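AMIL is commonly used to denote attention-based deep MIL (Ilse et al., 2018). It pools a bag of instance features into one slide-level embedding using learned attention weights. The sketch below shows the non-gated form of that pooling with fixed toy weights. It is an illustration under that assumption, not the implementation used in the paper, and the function name is hypothetical.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def attention_mil_pool(H: Matrix, V: Matrix, w: List[float]) -> Tuple[List[float], List[float]]:
    """Attention-based MIL pooling: a_k = softmax_k(w^T tanh(V h_k)), z = sum_k a_k h_k.

    H: N instance features of size D; V: L x D projection; w: length-L scoring vector.
    """
    scores = []
    for h in H:
        hidden = [math.tanh(sum(v_ij * h_j for v_ij, h_j in zip(row, h))) for row in V]
        scores.append(sum(w_i * g_i for w_i, g_i in zip(w, hidden)))
    m = max(scores)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    attn = [e / total for e in exps]
    z = [sum(a_k * h[d] for a_k, h in zip(attn, H)) for d in range(len(H[0]))]
    return z, attn


H = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # three patch features, D = 2
V = [[0.5, -0.2], [0.1, 0.3]]             # L = 2 (learned in practice)
w = [1.0, -1.0]                           # learned in practice
bag_embedding, weights = attention_mil_pool(H, V, w)
```

Because the attention weights sum to one, the bag embedding is a weighted average of patch features. A classifier head on `bag_embedding` then produces the slide-level prediction.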
Results and Analysis
1. Performance Improvement:
AugDiff consistently outperformed the compared setups across datasets. For example, on the prostate cancer dataset (SICAPv2, ResNet18 features), AugDiff achieved an average AUC of 0.749, roughly a 4% improvement over training without feature augmentation.
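For reference, the AUC figures can be read as follows: binary ROC AUC is the probability that a randomly chosen positive slide is scored above a randomly chosen negative one. The summary does not say whether "average AUC" averages over classes, folds, or runs. This sketch shows the per-class macro one-vs-rest option as one possibility. The function names are hypothetical. In practice `sklearn.metrics.roc_auc_score` (with `multi_class="ovr"`, `average="macro"`) is the usual library route.

```python
from typing import List, Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """Binary ROC AUC as the probability a random positive outscores a random negative (ties count half)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))


def macro_ovr_auc(labels: Sequence[int], class_scores: Sequence[Sequence[float]]) -> float:
    """Unweighted mean of one-vs-rest AUCs; assumes labels are column indices 0..C-1 of class_scores."""
    classes: List[int] = sorted(set(labels))
    per_class = [
        roc_auc([int(y == c) for y in labels], [row[c] for row in class_scores])
        for c in classes
    ]
    return sum(per_class) / len(per_class)


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

The pairwise loop is quadratic, which is fine at slide-level dataset sizes.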
2. Efficiency:
AugDiff substantially reduces the computational cost of traditional image augmentation. On WSIs at 10× magnification, it ran more than 30 times faster than image-level augmentation frameworks and used memory more efficiently. This lightweight design makes AugDiff well suited to high-resolution WSIs.
3. Augmentation Quality and Diversity:
Features augmented by AugDiff stayed closer to the real data distribution and avoided the instability associated with GANs. UMAP visualizations showed that AugDiff-generated samples formed more natural structures than those produced by image augmentation.
4. Generalization Ability:
Pretrained AugDiff models also generalized well to external datasets, consistently outperforming traditional image augmentation methods. For example, on the Unitopatho dataset, AugDiff achieved an average AUC of 0.911, surpassing several feature- and image-augmentation baselines.
Significance and Future Directions
AugDiff advances MIL-based WSI analysis in both theory and practice:
1. Scientific Contributions: As the first work to use DMs for feature augmentation in MIL, AugDiff demonstrates the value of this approach in computational pathology and offers a way to model scarce pathological data.
2. Practical Impact: AugDiff’s efficient online feature augmentation provides a valuable tool for rapid data generation, particularly for rare cancer diagnoses.
3. Future Directions: AugDiff could be extended to semi-supervised learning tasks or to datasets of extremely high-resolution WSIs. Open questions remain about how to better control the augmentation process and how to refine the mapping between augmented and original features.
AugDiff offers a novel, efficient, and stable method for feature augmentation within MIL frameworks, combining a sound theoretical basis with substantial potential for practical application.