Switch Diffusion Transformer: Synergizing Denoising Tasks with Sparse Mixture-of-Experts

Byeongjun Park, Hyojun Go, Jinyoung Kim, Sangmin Woo, Seokil Ham, Changick Kim

Korea Advanced Institute of Science and Technology (KAIST)    Twelvelabs

Switch-DiT establishes parameter isolation between conflicting denoising tasks without losing semantic information, improving diffusion model architectures through a sparse mixture-of-experts.

Summary

  • Our work explores how to leverage inter-task relationships among conflicting denoising tasks.
  • We propose a novel diffusion model architecture, the Switch Diffusion Transformer (Switch-DiT), which captures detailed inter-task relationships through a sparse mixture-of-experts.
  • We show that Switch-DiT builds tailored denoising paths across various generation scenarios.
Figure: Qualitative results of Switch-DiT on ImageNet.

Method: Switch-DiT

1. Conceptualizing diffusion models as a form of multi-task learning.

Diffusion models can be viewed as a form of multi-task learning, where they address a set of denoising tasks, one for each timestep \(t\). Each task aims to remove the noise added at its timestep and is trained to minimize the noise prediction loss \(L_{noise, t} = ||\epsilon - \epsilon_\theta(x_t, t)||_2^2\).
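
As a concrete illustration, below is a minimal PyTorch-style sketch of this per-timestep objective, assuming a standard DDPM forward process; the names eps_model and alphas_cumprod are our own and not from the paper.

import torch
import torch.nn.functional as F

def noise_prediction_loss(eps_model, x0, t, alphas_cumprod):
    # L_{noise, t} = ||eps - eps_theta(x_t, t)||_2^2 for a batch of clean images.
    # eps_model:      noise-prediction network (hypothetical handle for eps_theta)
    # x0:             clean images, shape (B, C, H, W)
    # t:              integer timesteps, shape (B,)
    # alphas_cumprod: cumulative product of the noise schedule, shape (T,)
    eps = torch.randn_like(x0)                            # ground-truth noise
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)           # \bar{alpha}_t per sample
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps  # forward process q(x_t | x_0)
    return F.mse_loss(eps_model(x_t, t), eps)             # mean of ||eps - eps_theta||^2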


2. Switch Diffusion Transformer

Figure: Overview of Switch Diffusion Transformer (Switch-DiT). Note that our Switch-DiT is built upon the DiT architecture.

Switch-DiT establishes task-specific paths for each denoising task within a single neural network \(\epsilon_\theta\) by utilizing a sparse mixture-of-experts (SMoE). Given the \(i\)-th block's input \(z_i \), the timestep embedding \(e_t\), and \(M\) experts \(E^1_i, E^2_i, \dots, E^M_i \), our SMoE layer outputs \(m(z_i) \) as follows: \[m(z_i) = \sum_{j=1}^{M}g_{i, j}(e_t)E_i^j(z_i), \] where the gating output \( g_{i} \) is defined by \[ g_i(e_t) = \mathrm{TopK}(\mathrm{Softmax}(h_i(e_t)), k). \] Then, \(z_i \cdot m(z_i) \) is processed by the remaining transformer block, while \(z_i \cdot (1-m(z_i)) \) is skip-connected to the end of the block.
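
The following is a minimal PyTorch sketch of such a timestep-gated SMoE layer. The class name SwitchSMoE, the two-layer MLP experts, and the emb_dim argument are our assumptions; the actual Switch-DiT block also carries the DiT components (e.g., attention and adaLN conditioning) omitted here.

import torch
import torch.nn as nn

class SwitchSMoE(nn.Module):
    # Sparse mixture-of-experts gated by the timestep embedding e_t (illustrative sketch).
    def __init__(self, dim, emb_dim, num_experts=4, k=2):
        super().__init__()
        self.k = k
        self.gate = nn.Linear(emb_dim, num_experts)  # h_i(e_t)
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
             for _ in range(num_experts)]
        )

    def forward(self, z, e_t):
        # g_i(e_t) = TopK(Softmax(h_i(e_t)), k): keep the k largest gates, zero out the rest.
        probs = torch.softmax(self.gate(e_t), dim=-1)                     # (B, M)
        topk_vals, topk_idx = probs.topk(self.k, dim=-1)
        gates = torch.zeros_like(probs).scatter(-1, topk_idx, topk_vals)  # sparse g_i(e_t)

        # m(z) = sum_j g_{i,j}(e_t) * E_i^j(z), with token input z of shape (B, N, D)
        expert_outs = torch.stack([E(z) for E in self.experts], dim=-1)   # (B, N, D, M)
        m = (expert_outs * gates[:, None, None, :]).sum(dim=-1)           # (B, N, D)

        # z * m(z) continues through the rest of the block; z * (1 - m(z)) is skip-connected to its end.
        return z * m, z * (1.0 - m), gates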


3. Diffusion Prior Loss

Figure: Gating Outputs Integration.

We first aggregate the gating outputs across all transformer blocks, and the integrated outputs are then trained to match those similarly derived from DTR (details in our paper): \[ L_{dp, t} = D_{JS} \Big( \frac{\tilde{p}_{tot}(e_t)}{N}, \frac{w_t^{prior}}{kN} \Big). \] Combined with the noise prediction loss, Switch-DiT is trained with the objective \( L_{noise, t} + \lambda_{dp} L_{dp, t} \).
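
The sketch below shows one way this term could be computed, assuming the per-block gating outputs and the prior weighting \(w_t^{prior}\) are already available; the function names and the naive dense implementation are our assumptions, not the paper's code.

import torch

def js_divergence(p, q, eps=1e-8):
    # Jensen-Shannon divergence between two (batched) discrete distributions along the last dim.
    p, q = p.clamp_min(eps), q.clamp_min(eps)
    m = 0.5 * (p + q)
    return 0.5 * ((p * (p / m).log()).sum(-1) + (q * (q / m).log()).sum(-1))

def diffusion_prior_loss(gates_per_block, w_prior, k):
    # L_{dp, t} = D_JS( p_tot(e_t) / N , w_t^{prior} / (k * N) )
    # gates_per_block: list of N gating outputs g_i(e_t), each of shape (B, M)
    # w_prior:         diffusion prior w_t^{prior}, shape (B, M)
    # k:               number of experts kept by TopK
    N = len(gates_per_block)
    p_tot = torch.stack(gates_per_block, dim=0).sum(dim=0)  # aggregate gating outputs over blocks
    return js_divergence(p_tot / N, w_prior / (k * N)).mean()

Under these assumptions, the total training objective would simply be the noise prediction loss plus \(\lambda_{dp}\) times this prior term.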

Experimental Results

1. Improved quality of generated images

Figure: Consistent improvements in image quality. Switch-DiT also surpasses the GFLOPs-FID trade-off of DiT.


Switch-DiT shows significant improvements in image quality across different model sizes. Given the strong correlation observed between GFLOPs and FID scores in DiT, the performance gain of Switch-DiT cannot be attributed solely to its additional parameters and GFLOPs. Rather, Switch-DiT surpasses the GFLOPs-FID trade-off of DiT, leading to a notable performance improvement.



2. Tailored Denoising Path Construction



Figure: Even with the same configuration, the constructed denoising paths vary with the model size and dataset.


Switch-DiT constructs tailored denoising paths across various generation scenarios, even when using the same diffusion prior.

BibTeX

@article{park2024switch,
  title={Switch Diffusion Transformer: Synergizing Denoising Tasks with Sparse Mixture-of-Experts},
  author={Park, Byeongjun and Go, Hyojun and Kim, Jin-Young and Woo, Sangmin and Ham, Seokil and Kim, Changick},
  journal={arXiv preprint arXiv:2403.09176},
  year={2024}
}

Acknowledgement

This website is adapted from Nerfies and LLaVA, licensed under a Creative Commons Attribution-ShareAlike 4.0 International License.