Affect-Diff: Multimodal Emotion Recognition with a Diffusion Generative Prior
Work in Progress
A research system that fuses text, audio, and video into a shared VAE latent space,
then uses a causal graph + 1D diffusion U-Net to model the distribution of affect,
with explainability via counterfactual generation and Gumbel-Softmax causal edges.
Diffusion Models
Multimodal Fusion
Causal Reasoning
VAE
CMU-MOSEI
PyTorch Lightning
Classifier-Free Guidance
Wandb / SLURM
Motivation
Emotion recognition from multimodal signals (speech, language, face) is a core problem in affective
computing, HCI, and clinical AI. Existing methods mostly treat it as a classification problem:
they fuse modalities and predict a label. This framing ignores two critical questions:
- Which modality actually caused this emotion? (Text? Tone? Facial expression?)
- What would the emotion have been if one modality had been different? (Counterfactual explainability)
Affect-Diff addresses both. A causal attention graph learns which modality drives each prediction.
A conditional diffusion model then allows counterfactual hallucination, sampling what the latent
affect would look like under a different emotion label, enabling model-level interpretability.
Architecture Overview
Modality Encoders (text · audio · video projections)
↓ LayerNorm per modality (energy-level normalization) ↓
CausalAttentionGraph
Differentiable 3×3 adjacency matrix (T↔A↔V) via Gumbel-Softmax with temperature annealing. Masks self-loops. Returns causal influence weights per modality.
↓ Causal-weighted modality fusion → VAE bottleneck ↓
VAE Latent Bottleneck
fc_mu + fc_logvar → reparameterize → z (B, T, latent_dim=256). β-KL loss for regularization. Logvar clamped [−10, 2] for stability (see the sketch below the diagram).
↓ DDPM forward process: q(z_t | z_0) ↓
1D U-Net Diffusion (AffectiveDiffusion)
Cosine β-schedule. ResnetBlock1D with SiLU + GroupNorm. Time embeddings + label embeddings (CFG) + causal influence projection, all summed into a single conditioning vector. Classifier-Free Guidance at inference (cfg_scale > 1.0).
↓ p_sample_loop → reconstructed z → classifier ↓
Emotion Classifier + Counterfactual
Linear classifier on denoised z. Counterfactual hallucination: sample z under a different label to explain why the model changed its prediction.
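A minimal sketch of the latent path above, under assumed names and dimensions (`fused_dim`, the helper functions, and the schedule length are illustrative, not the project's exact code): the fused representation is mapped to a clamped-logvar Gaussian, reparameterized into z, and then noised with the closed-form forward process q(z_t | z_0) under a cosine β-schedule.

```python
import math
import torch
import torch.nn as nn

def cosine_beta_schedule(timesteps, s=0.008):
    # Cosine schedule (Nichol & Dhariwal): smooth alpha-bar curve, betas clipped for stability.
    t = torch.linspace(0, timesteps, timesteps + 1) / timesteps
    f = torch.cos((t + s) / (1 + s) * math.pi / 2) ** 2
    alphas_cumprod = f / f[0]
    betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return betas.clamp(max=0.999)

class LatentBottleneck(nn.Module):
    # Stand-in for the VAE bottleneck: fused features -> (mu, logvar) -> z.
    def __init__(self, fused_dim=768, latent_dim=256):
        super().__init__()
        self.fc_mu = nn.Linear(fused_dim, latent_dim)
        self.fc_logvar = nn.Linear(fused_dim, latent_dim)

    def forward(self, fused):                                   # fused: (B, T, fused_dim)
        mu = self.fc_mu(fused)
        logvar = self.fc_logvar(fused).clamp(-10.0, 2.0)        # stability clamp from the spec
        z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()    # reparameterization trick
        kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
        return z, kl

def q_sample(z0, t, alphas_cumprod):
    # Closed-form forward process: q(z_t | z_0) = N(sqrt(a_bar_t) * z_0, (1 - a_bar_t) I).
    a_bar = alphas_cumprod[t].view(-1, 1, 1)                    # t: (B,) integer timesteps
    noise = torch.randn_like(z0)
    return a_bar.sqrt() * z0 + (1.0 - a_bar).sqrt() * noise, noise

betas = cosine_beta_schedule(timesteps=1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)              # a_bar_t consumed by q_sample
```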
Key Technical Contributions
Causal Attention Graph
Learns a differentiable adjacency matrix (T↔A↔V) using scaled dot-product + Gumbel-Softmax. Temperature anneals over training from soft to hard edges. Self-loops masked. Output: per-sample causal influence weights injected into the diffusion U-Net.
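A minimal sketch of this mechanism with assumed shapes and names (pooled per-modality summaries of size `dim`, a linear temperature schedule), not the project's exact module: scaled dot-product scores over (T, A, V) go through Gumbel-Softmax with self-loops masked, yielding per-sample edge weights and influence scores.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalAttentionGraph(nn.Module):
    # Sketch: differentiable 3x3 adjacency over (text, audio, video) modality summaries.
    def __init__(self, dim=256, tau_start=1.0, tau_end=0.1):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.tau_start, self.tau_end = tau_start, tau_end

    def forward(self, m, progress=0.0, hard=False):
        # m: (B, 3, dim) pooled modality summaries; progress in [0, 1] anneals soft -> hard edges.
        tau = self.tau_start + (self.tau_end - self.tau_start) * progress
        scores = self.q(m) @ self.k(m).transpose(1, 2) / m.size(-1) ** 0.5   # (B, 3, 3)
        eye = torch.eye(3, device=m.device, dtype=torch.bool)
        scores = scores.masked_fill(eye, float("-inf"))                      # no self-loops
        adj = F.gumbel_softmax(scores, tau=tau, hard=hard, dim=-1)           # rows sum to 1
        influence = adj.sum(dim=1)                                           # incoming weight per modality
        influence = influence / influence.sum(dim=-1, keepdim=True)          # normalize to a distribution
        return adj, influence
```

The resulting per-sample `influence` vector is the causal conditioning signal that the diffusion U-Net consumes, as described in the next contribution.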
1D Diffusion U-Net
Custom UNet1D with down/upsample ResnetBlock1D stacks operating on the temporal latent sequence. Conditions simultaneously on: (1) sinusoidal timestep embedding, (2) label embedding for CFG, (3) causal influence projection. Skip connections at all resolutions.
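A rough sketch of how the three conditioning signals might be merged, with assumed names (`Conditioning`, `cond_dim`) rather than the project's exact UNet1D internals: a sinusoidal timestep embedding, a label embedding with a reserved null index for CFG, and a projection of the 3-dimensional causal influence vector are summed into one vector that every ResnetBlock1D receives.

```python
import math
import torch
import torch.nn as nn

def timestep_embedding(t, dim):
    # Sinusoidal embedding of the diffusion timestep: (B,) -> (B, dim), dim assumed even.
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

class Conditioning(nn.Module):
    # Sketch: fuse timestep, emotion label (+ null index for CFG), and causal influence
    # weights into the single conditioning vector consumed by each ResnetBlock1D.
    def __init__(self, cond_dim=256, num_labels=6):
        super().__init__()
        self.cond_dim = cond_dim
        self.time_mlp = nn.Sequential(
            nn.Linear(cond_dim, cond_dim), nn.SiLU(), nn.Linear(cond_dim, cond_dim))
        self.label_emb = nn.Embedding(num_labels + 1, cond_dim)   # last index = null label
        self.causal_proj = nn.Linear(3, cond_dim)                 # (text, audio, video) weights

    def forward(self, t, labels, causal_influence):
        cond = self.time_mlp(timestep_embedding(t, self.cond_dim))
        return cond + self.label_emb(labels) + self.causal_proj(causal_influence)  # (B, cond_dim)
```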
Classifier-Free Guidance
At inference, the model runs two forward passes (conditioned + unconditional using a null label) and interpolates: pred_noise = uncond + scale × (cond − uncond). Enables controllable emotion-conditioned generation.
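A minimal sketch of that interpolation step, assuming a denoiser `model(z_t, t, labels, causal)` that predicts noise and a reserved `null_label` index (both hypothetical names):

```python
import torch

@torch.no_grad()
def cfg_pred_noise(model, z_t, t, labels, causal, null_label, cfg_scale=2.0):
    # Two passes: conditioned on the target emotion and on the null label, then interpolate.
    cond = model(z_t, t, labels, causal)
    uncond = model(z_t, t, torch.full_like(labels, null_label), causal)
    return uncond + cfg_scale * (cond - uncond)   # cfg_scale > 1.0 sharpens the conditioning
```

For this to work, the model has to see the null label during training; the usual recipe is to randomly replace the true label with the null index for a small fraction of batches so the unconditional branch is learned alongside the conditional one.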
Counterfactual Explainability
Given a predicted emotion, sample z under a different target label. The delta between the two reconstructions reveals which features drove the original decision, providing causal attribution beyond standard attention visualization.
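A sketch of what that procedure could look like, assuming a `p_sample_loop(z_T, labels, causal)` reverse sampler and a linear `classifier` head (names illustrative): both reconstructions start from the same noised latent, so the label intervention is the only thing that changes.

```python
import torch

@torch.no_grad()
def counterfactual_delta(p_sample_loop, classifier, z_T, causal, pred_label, target_label):
    # Reconstruct the latent under the original prediction and under the intervened label;
    # the element-wise difference attributes which latent features drove the decision.
    z_factual = p_sample_loop(z_T, labels=pred_label, causal=causal)
    z_counter = p_sample_loop(z_T, labels=target_label, causal=causal)
    delta = (z_factual - z_counter).abs().mean(dim=1)   # (B, latent_dim): per-feature shift over time
    flipped_logits = classifier(z_counter.mean(dim=1))  # check the prediction actually moves
    return delta, flipped_logits
```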
Training Setup
Dataset   : CMU-MOSEI (23,454 utterances, 6 emotion categories)
Backbone  : RoBERTa (text) · wav2vec (audio) · facial AUs (video)
Training  : PyTorch Lightning, DDP (multi-GPU, SLURM cluster); see the Trainer sketch below
Optimizer : AdamW, gradient clipping (norm=1.0, DeepMind standard)
Logging   : Weights & Biases (project: "Affect-Diff-CVPR")
Monitor   : val_loss (patience=50), EarlyStopping + LR warmup
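A minimal sketch of how that setup might be wired in PyTorch Lightning; hyperparameters and the `AffectDiffModule` name are placeholders, not the project's actual config.

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

trainer = pl.Trainer(
    accelerator="gpu",
    devices="auto",                      # all GPUs allocated by the SLURM job
    strategy="ddp",
    gradient_clip_val=1.0,               # clip gradients by global norm
    logger=WandbLogger(project="Affect-Diff-CVPR"),
    callbacks=[
        EarlyStopping(monitor="val_loss", patience=50),
        ModelCheckpoint(monitor="val_loss", save_top_k=1),
    ],
)
# LR warmup would live in the LightningModule's configure_optimizers / scheduler.
# trainer.fit(AffectDiffModule(...), datamodule=...)
```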
Combined Loss = CE (classification)
+ λ₁ · MSE (diffusion noise prediction)
+ λ₂ · β-KL (VAE regularization)
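As a concrete reading of that objective, a short sketch with placeholder λ weights (the actual values are not stated on this page):

```python
import torch.nn.functional as F

def combined_loss(logits, labels, pred_noise, true_noise, kl, lambda1=1.0, lambda2=1e-3):
    # CE on the classifier head, MSE on the diffusion noise prediction, β/λ₂-weighted VAE KL.
    ce = F.cross_entropy(logits, labels)
    mse = F.mse_loss(pred_noise, true_noise)
    total = ce + lambda1 * mse + lambda2 * kl
    return total, {"ce": ce.detach(), "mse": mse.detach(), "kl": kl.detach()}
```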
Module Structure
models/
  encoders/
    text_encoder.py       # Transformer projection for transcript features
    audio_encoder.py      # Conv1D projection for wav2vec features
    video_encoder.py      # FC projection for facial action units
  fusion/
    latent_bottleneck.py  # VAE + CausalAttentionGraph + weighted fusion
  diffusion/
    unet_1d.py            # 1D U-Net with time/label/causal conditioning
    diffusion_utils.py    # DDPM forward/reverse, CFG, cosine schedule
    causal_graph.py       # CausalAttentionGraph (standalone, reusable)
modules/
  affect_diff_module.py   # PyTorch Lightning training module
Research angle: This project bridges three active areas: multimodal representation learning,
score-based generative models for structured latent spaces, and causal explainability.
The counterfactual generation capability positions it as an interpretability tool, not just a classifier.
Target venue: CVPR / EMNLP 2026.
Tech Stack
- Core: PyTorch 2.x, PyTorch Lightning (DDP, multi-GPU), Hydra (config management)
- Pretrained backbones: RoBERTa (text), wav2vec 2.0 (audio), facial action units (video)
- Dataset: CMU-MOSEI (6-class sentiment/emotion, ~23K utterances)
- Training infra: SLURM cluster, W&B logging, ModelCheckpoint + EarlyStopping
- Explainability: Gumbel-Softmax causal graph, counterfactual hallucination via DDPM