Variational mHC
Variational mHC introduces stochasticity into the mixing weights, turning the hyper-connections into a Bayesian-inspired ensemble of historical states.
Stochastic Mixing Math
When stochastic=True is set, the model draws its mixing weights as relaxed samples from a categorical distribution using the Gumbel-Softmax trick.
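Sampling from a discrete categorical distribution is normally non-differentiable. Gumbel-Softmax provides a continuous, differentiable approximation, producing relaxed mixing weights \(y_i\) over the \(K\) historical states:
\[ y_i = \frac{\exp\left((\log \pi_i + g_i) / \tau\right)}{\sum_{j=1}^{K} \exp\left((\log \pi_j + g_j) / \tau\right)} \]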
Where:
- \(\pi_i\) are the manifold-constrained mixing weights (probabilities).
- \(g_i\) are independent samples drawn from the Gumbel(0, 1) distribution: \(g = -\log(-\log(u))\) where \(u \sim \text{Uniform}(0, 1)\).
- \(\tau\) is the temperature.
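The sampling procedure described by these symbols can be sketched in plain Python. This is a minimal illustration of the standard Gumbel-Softmax form, softmax of \((\log \pi_i + g_i) / \tau\), not the library's implementation; the function names `sample_gumbel` and `gumbel_softmax` and the example weights are assumptions made for this sketch.

```python
import math
import random


def sample_gumbel(rng):
    """Draw one Gumbel(0, 1) sample via g = -log(-log(u)), u ~ Uniform(0, 1)."""
    # Clamp u away from 0 so both logarithms stay finite.
    u = max(rng.random(), 1e-12)
    return -math.log(-math.log(u))


def gumbel_softmax(pi, tau, rng):
    """Relaxed categorical sample: softmax((log(pi_i) + g_i) / tau).

    Assumes every entry of pi is strictly positive and tau > 0.
    """
    scores = [(math.log(p) + sample_gumbel(rng)) / tau for p in pi]
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)      # fixed seed so the sketch is reproducible
pi = [0.7, 0.2, 0.1]        # hypothetical mixing weights over three states
y = gumbel_softmax(pi, tau=1.0, rng=rng)
print(y)                    # three non-negative weights that sum to 1
```

Each call returns a fresh point on the probability simplex, so repeated calls with the same \(\pi\) yield different mixtures of the historical states.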
Sampling regimes:
- Low Temperature (\(\tau \to 0\)): Samples approach one-hot vectors, so the model effectively follows a single historical state per sample.
- High Temperature (\(\tau \to \infty\)): Samples approach uniform weights, so the model averages all historical states equally.
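The two regimes can be checked numerically. The sketch below, assuming the standard Gumbel-Softmax form and hypothetical weights, averages the largest sampled weight over many draws: a value near 1 indicates near one-hot samples, and a value near \(1/K\) indicates near-uniform samples. The helper `mean_peak` is an illustrative name, not part of the library.

```python
import math
import random


def gumbel_softmax(pi, tau, rng):
    """Relaxed categorical sample: softmax((log(pi_i) + g_i) / tau)."""
    scores = []
    for p in pi:
        u = max(rng.random(), 1e-12)          # avoid log(0)
        g = -math.log(-math.log(u))           # Gumbel(0, 1) noise
        scores.append((math.log(p) + g) / tau)
    peak = max(scores)                        # stabilise the exponentials
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def mean_peak(pi, tau, n_samples=2000, seed=0):
    """Average of the largest sampled weight across n_samples draws."""
    rng = random.Random(seed)
    peaks = (max(gumbel_softmax(pi, tau, rng)) for _ in range(n_samples))
    return sum(peaks) / n_samples


pi = [0.7, 0.2, 0.1]                  # hypothetical mixing weights, K = 3
peak_cold = mean_peak(pi, tau=0.01)   # low temperature: close to 1
peak_hot = mean_peak(pi, tau=100.0)   # high temperature: close to 1/3
print(peak_cold, peak_hot)
```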
Why use Variational Mixing?
- Ensemble Robustness: By sampling different history paths during training, the model behaves like an ensemble. This discourages overfitting to specific layer connections.
- Manifold Exploration: Stochasticity forces the gradient to explore alternative historical connections that might be ignored by a greedy deterministic optimizer.
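- Uncertainty Quantification: At inference time, multiple stochastic passes can be used to measure the variance (uncertainty) of the manifold connections. Because model.eval() disables sampling, these passes must run with the stochastic path active.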
Technical Behavior
Training vs. Evaluation
Variational mixing respects the PyTorch module state:
- model.train(): Stochastic. Gumbel-Softmax is active.
- model.eval(): Deterministic. It uses the exact manifold projection \(\pi_i\) without Gumbel noise.
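The train/eval switch can be illustrated with a toy PyTorch module. This is a sketch under stated assumptions, not the library's code: `ToyStochasticMixer` is a hypothetical class, and a plain softmax over learnable logits stands in for the manifold projection that produces \(\pi\).

```python
import torch
from torch import nn


class ToyStochasticMixer(nn.Module):
    """Hypothetical stand-in for a variational mixing layer."""

    def __init__(self, num_states, temperature=1.0):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(num_states))
        self.temperature = temperature

    def mixing_weights(self):
        # Assumption: softmax replaces the real manifold projection pi.
        log_pi = torch.log_softmax(self.logits, dim=-1)
        if not self.training:
            return log_pi.exp()                     # eval: exact pi, no noise
        u = torch.rand_like(log_pi).clamp_min(1e-12)
        g = -torch.log(-torch.log(u))               # Gumbel(0, 1) noise
        return torch.softmax((log_pi + g) / self.temperature, dim=-1)

    def forward(self, states):
        # states: (num_states, ...) stack of historical states.
        return torch.tensordot(self.mixing_weights(), states, dims=1)


torch.manual_seed(0)
mixer = ToyStochasticMixer(num_states=3)

mixer.eval()
eval_a, eval_b = mixer.mixing_weights(), mixer.mixing_weights()

mixer.train()
train_a, train_b = mixer.mixing_weights(), mixer.mixing_weights()

print(torch.equal(eval_a, eval_b))    # eval weights repeat exactly
print(torch.equal(train_a, train_b))  # train weights are resampled each call
```

Branching on `self.training` is the usual PyTorch idiom for layers that behave differently in the two modes, in the same way dropout does.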
Parameters
Set these in MHCSequential or MHCSkip:
- stochastic: Set stochastic=True to enable variational mixing.
- temperature: The Gumbel-Softmax temperature \(\tau\). We recommend starting with 1.0 and annealing to 0.5 for sharper connections.
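One possible way to compute the recommended anneal from 1.0 to 0.5 is an exponential schedule. The sketch below is an assumption: the function `annealed_temperature` and its arguments are hypothetical, and how the resulting value is passed to MHCSequential or MHCSkip during training is not specified here.

```python
def annealed_temperature(step, total_steps, start=1.0, end=0.5):
    """Exponentially interpolate the temperature from start to end."""
    # Clamp progress to [0, 1] so steps past total_steps stay at `end`.
    progress = min(max(step / total_steps, 0.0), 1.0)
    return start * (end / start) ** progress


total_steps = 10_000  # hypothetical training length
schedule = [annealed_temperature(s, total_steps) for s in (0, 5_000, 10_000)]
print(schedule)       # starts at 1.0, ends at 0.5
```

An exponential curve lowers the temperature by a constant ratio per step; a linear schedule between the same endpoints would be an equally reasonable choice.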