Custom Layers Integration
If you have a non-sequential model (such as a multi-branch transformer or a graph neural network) and still want automated history management and dashboard logging, you can integrate mHC into your own layers with two tools: the @mhc_compatible decorator and manual HistoryBuffer management.
1. The @mhc_compatible Decorator
This decorator is a core part of the "Honey Badger" ecosystem. It marks a class as a first-class mHC module. This signals to utilities such as inject_mhc and MHCLightningCallback that the module handles its own hyper-connectivity.
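A class decorator of this kind is typically just a marker. The sketch below assumes nothing about the real mhc internals. The decorator name mhc_compatible_sketch, the attribute _mhc_compatible and the helper is_mhc_compatible are all hypothetical. They only show how a tagged class can be recognized later by other utilities.

```python
# Hypothetical sketch of a marker-style class decorator; not the mhc implementation.

_MARKER = "_mhc_compatible"  # hypothetical attribute name


def mhc_compatible_sketch(cls):
    """Tag a class so that other utilities can recognize it later."""
    setattr(cls, _MARKER, True)
    return cls  # the class is returned unchanged apart from the tag


def is_mhc_compatible(obj):
    """Return True if obj (a class or an instance) carries the tag."""
    return getattr(obj, _MARKER, False) is True


@mhc_compatible_sketch
class TaggedLayer:
    pass


class PlainLayer:
    pass


print(is_mhc_compatible(TaggedLayer()))  # True
print(is_mhc_compatible(PlainLayer()))   # False
```

The tag is a class attribute, so it is visible on instances and is also inherited by subclasses. A real implementation may choose differently.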
When to use it:
- You are building a custom nn.Module that isn't a simple list of layers.
- You want your custom layer's mixing weights to show up in the PyTorch Lightning monitoring dashboard automatically.
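This page does not specify how MHCSkip parameterizes its "mixing weights". A common way to weight a set of states is a softmax over learnable logits, which keeps the weights positive and summing to one. The function below illustrates that idea with plain Python lists standing in for tensors. Two choices here are assumptions and not MHCSkip's documented behavior: the logit ordering [current, *history], and the use of only the first len(states) logits while history is partially filled.

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mix(current, history, logits):
    """Hypothetical mixing step: weighted sum of the current and past states.

    current: list[float]; history: list of list[float], each the same length
    as current; logits: one value per slot, ordered [current, *history].
    """
    states = [current] + list(history)
    if len(logits) < len(states):
        raise ValueError("need at least one logit per state")
    # Assumption: while history is not yet full, use only the first logits.
    weights = softmax(logits[: len(states)])
    return [
        sum(w * s[i] for w, s in zip(weights, states))
        for i in range(len(current))
    ]


# Equal logits give equal weights, i.e. a plain average of the states.
print(mix([3.0, 0.0], [[1.0, 2.0]], [0.0, 0.0]))  # [2.0, 1.0]
```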
```python
from mhc import mhc_compatible, MHCSkip
import torch.nn as nn


@mhc_compatible
class MySuperLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # 1. Add the mixing engine
        self.skip = MHCSkip(max_history=4)
        self.layer = nn.Linear(dim, dim)

    def forward(self, x, history=None):
        # 2. Main transformation
        x = self.layer(x)
        # 3. Apply mHC mix
        # Note: 'history' is typically passed down by the container;
        # fall back to an empty list when the layer is called on its own.
        x = self.skip(x, history or [])
        return x
```
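The comment in forward says the container usually supplies history. Below is a hypothetical container loop: each layer receives the states produced so far, and the container records each output. Plain callables on floats stand in for nn.Module layers. ToyContainer and add_mean_of_history are illustrative names only, and the real MHCSequential may slice or store history differently.

```python
from collections import deque


class ToyContainer:
    """Hypothetical stand-in for a history-passing container such as MHCSequential."""

    def __init__(self, layers, max_history=4):
        self.layers = layers
        self.max_history = max_history

    def __call__(self, x):
        history = deque(maxlen=self.max_history)  # fresh for every forward call
        for layer in self.layers:
            x = layer(x, list(history))  # pass a snapshot of the past states
            history.append(x)            # remember this layer's output
        return x


def add_mean_of_history(x, history=None):
    # Mirrors the `history or []` idiom from MySuperLayer.forward.
    history = history or []
    return x + (sum(history) / len(history) if history else 0.0)


model = ToyContainer([add_mean_of_history] * 4, max_history=2)
print(model(1.0))  # 6.25
```

This container creates its history inside each call, so nothing leaks between inputs. Section 3 shows what happens when a buffer instead lives on the model across calls.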
2. Manual History Management
When NOT using MHCSequential, you are the "Governor" of the HistoryBuffer. You must decide precisely when a state is important enough to be remembered.
Pattern: The Manual Buffer Loop
```python
from mhc import HistoryBuffer, MHCSkip
import torch.nn as nn


class MyNonSequentialModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Create a buffer for tracking the last 4 feature maps
        self.buffer = HistoryBuffer(max_history=4)
        self.skip = MHCSkip(max_history=4)
        self.layer = nn.Linear(64, 64)

    def forward(self, x):
        # 1. Fetch current history from the buffer
        hist = self.buffer.get()
        # 2. Transform input and apply the skip
        out = self.layer(x)
        out = self.skip(out, hist)
        # 3. Update the buffer with the new state
        # This will automatically evict the oldest state if count > max_history
        self.buffer.append(out)
        return out
```
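The eviction comment describes a bounded first-in, first-out queue. This page does not show HistoryBuffer's real implementation. If it behaves as described, collections.deque with maxlen reproduces the behavior: once full, each append drops the oldest entry. ToyHistoryBuffer and its get/append/clear methods are hypothetical stand-ins.

```python
from collections import deque


class ToyHistoryBuffer:
    """Hypothetical bounded FIFO mirroring the eviction rule described above."""

    def __init__(self, max_history):
        self._states = deque(maxlen=max_history)

    def get(self):
        # Return a copy so callers cannot mutate the buffer by accident.
        return list(self._states)

    def append(self, state):
        # When full, deque(maxlen=...) silently drops the oldest entry.
        self._states.append(state)

    def clear(self):
        self._states.clear()


buf = ToyHistoryBuffer(max_history=4)
for step in range(6):
    buf.append(f"state_{step}")
print(buf.get())  # ['state_2', 'state_3', 'state_4', 'state_5']
```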
3. Handling Resets in Training Loops
If you manage history manually via a HistoryBuffer attribute, remember that these buffers persist in the model's memory across batch iterations. In a standard training loop, you must clear the buffer between batches to prevent "Temporal Bleeding" (where features from batch 1 affect batch 2).
The "Auto-Reset" Pattern:
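```python
import torch.nn.functional as F


# Method of your LightningModule; assumes self.model provides clear_history()
def training_step(self, batch, batch_idx):
    # CRITICAL: Always clear history at the start of each training step (batch)
    self.model.clear_history()
    x, y = batch
    preds = self.model(x)
    loss = F.cross_entropy(preds, y)
    return loss
```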
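The toy model below shows why the reset matters. It keeps a buffer across calls in the same way MyNonSequentialModel does. Without clear_history(), the second batch's output depends on the first batch. With a reset before each step, the same input gives the same output. ToyModel, its clear_history method, and the number-valued "features" are all hypothetical. In a real model, clear_history would need to reset every buffer the model owns.

```python
from collections import deque


class ToyModel:
    """Hypothetical model that, like MyNonSequentialModel, keeps history across calls."""

    def __init__(self, max_history=4):
        self.buffer = deque(maxlen=max_history)

    def __call__(self, x):
        out = x + sum(self.buffer)  # output depends on remembered states
        self.buffer.append(out)
        return out

    def clear_history(self):
        self.buffer.clear()


model = ToyModel()
batches = [1.0, 1.0]

# Without a reset, batch 2 "sees" features from batch 1 (temporal bleeding).
leaky = [model(b) for b in batches]

# With a reset before each step, identical batches give identical outputs.
clean = []
for b in batches:
    model.clear_history()
    clean.append(model(b))

print(leaky)  # [1.0, 2.0]
print(clean)  # [1.0, 1.0]
```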
4. Why @mhc_compatible is Mandatory
The decorator is how the rest of the mhc ecosystem recognizes your layer. It affects three behaviors:
- Dashboarding: The MHCLightningCallback performs a recursive search for all modules. It only inspects and logs the logits of layers that are explicitly declared @mhc_compatible.
- Projection Safety: If you use inject_mhc on a large model that contains some of your custom layers, the injector will see the decorator and skip those layers to avoid wrapping them twice.
- Device-Awareness: It allows standard mHC utilities to help move your nested buffers between CPU and GPU automatically.
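The first two bullets describe a tree walk that treats tagged modules differently. The sketch below models that walk with a tiny hypothetical Node class instead of nn.Module; in real PyTorch, nn.Module.modules() already performs the recursive walk. find_tagged collects what a logging callback would inspect. inject_skip_tagged wraps only untagged leaves, mimicking how an injector could avoid double-wrapping. None of these names are mhc APIs.

```python
class Node:
    """Hypothetical stand-in for nn.Module: a name plus child nodes."""

    def __init__(self, name, children=(), tagged=False):
        self.name = name
        self.children = list(children)
        self.tagged = tagged  # plays the role of the @mhc_compatible marker


def walk(node):
    # Depth-first, pre-order traversal (self first, then descendants).
    yield node
    for child in node.children:
        yield from walk(child)


def find_tagged(root):
    # What a logging callback might do: keep only modules that opted in.
    return [n.name for n in walk(root) if n.tagged]


def inject_skip_tagged(node):
    # What an injector might do: wrap untagged leaves, leave tagged subtrees alone.
    for i, child in enumerate(node.children):
        if child.tagged:
            continue  # manages its own connections; do not wrap it twice
        if child.children:
            inject_skip_tagged(child)
        else:
            node.children[i] = Node(f"Wrapped({child.name})", [child])


model = Node("model", [
    Node("embed"),
    Node("block", [Node("attn"), Node("my_super_layer", tagged=True)]),
    Node("head"),
])
print(find_tagged(model))  # ['my_super_layer']
inject_skip_tagged(model)
print([n.name for n in walk(model)])
```

After injection, embed, attn and head are each wrapped once, while my_super_layer is left untouched.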