LSP Softmax Aggregator¶
A differentiable aggregator that combines five canonical LSP operators through softmax-weighted mixing driven by the tree’s andness parameter ‘a’.
The softmax weights depend on the distance from the tree’s ‘a’ parameter to each operator’s canonical andness value (configurable; evenly spaced by default).
- class bacon.aggregators.lsp.softmax_lsp.LspSoftmaxAggregator(tau: float = 0.5, eps: float = 1e-06, centers: list = None)¶
LSP Softmax Aggregator using 5 canonical operators.
Computes \(F(x, y) = \sum_i w_i A_i(x, y)\) where weights are determined by the tree’s andness parameter ‘a’:
\[w_i = \operatorname{softmax}\left(-\frac{(a - c_i)^2}{\tau}\right)\]
Each operator has a canonical “center” andness value \(c_i\) (evenly spaced by default):
\[A_0\,(\text{product})\!:\; a=1.5,\; A_1\,(\min)\!:\; a=1.0,\; A_2\,(\text{avg})\!:\; a=0.5,\; A_3\,(\max)\!:\; a=0.0,\; A_4\,(\text{prob\_sum})\!:\; a=-0.5\]
- Concentration can be encouraged via:
Tau annealing: decrease tau over training (use anneal_tau() or set_tau())
Entropy regularization: add entropy_loss(a) to the training loss to push the tree’s ‘a’ toward the canonical centers
- Parameters:
tau – Temperature for softmax (default: 0.5). Lower = sharper selection.
eps – Small constant for numerical stability (default: 1e-6).
centers – Custom operator centers (default: [1.5, 1.0, 0.5, 0.0, -0.5]).
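The weight formula above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the helper name `softmax_weights` is hypothetical, and the sketch leaves out the `eps` parameter because the text does not say where it is applied.

```python
import math

# Default centers from the docs: product, min, avg, max, prob_sum.
DEFAULT_CENTERS = [1.5, 1.0, 0.5, 0.0, -0.5]

def softmax_weights(a, tau=0.5, centers=DEFAULT_CENTERS):
    """Hypothetical helper: softmax over -(a - c_i)^2 / tau."""
    logits = [-((a - c) ** 2) / tau for c in centers]
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# At a = 1.0 the largest weight falls on the operator centered at 1.0 (min).
weights = softmax_weights(a=1.0)
# A smaller tau concentrates almost all mass on that single operator.
sharp_weights = softmax_weights(a=1.0, tau=0.01)
```

With the default `tau=0.5` the neighbouring operators still receive noticeable weight, which is why the docs suggest annealing or entropy regularization to concentrate the mixture.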
- aggregate_float(values: Sequence[float], a: float, weights: Sequence[float]) float¶
Aggregate N scalar values with scalar andness a and N weights.
- aggregate_tensor(values: Sequence[Any], a: Tensor, weights: Sequence[Any] | Tensor) Tensor¶
Aggregate N tensors with andness a and N weights.
- anneal_tau(epoch: int, max_epochs: int, final_tau: float = 0.01, schedule: str = 'exponential')¶
Anneal temperature based on training progress.
- Parameters:
epoch – Current epoch (0-indexed)
max_epochs – Total number of epochs
final_tau – Target tau at end of training (default: 0.01)
schedule – “exponential” or “linear” (default: “exponential”)
Example
for epoch in range(max_epochs):
    agg.anneal_tau(epoch, max_epochs, final_tau=0.01)
    train_one_epoch(…)
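The two schedule names could correspond to interpolations like the ones below. This is a sketch under stated assumptions: the function `annealed_tau`, the `initial_tau` argument, and the choice of `epoch / (max_epochs - 1)` as the progress measure are all hypothetical, since the docs do not give the exact formula used by `anneal_tau()`.

```python
def annealed_tau(epoch, max_epochs, initial_tau=0.5, final_tau=0.01,
                 schedule="exponential"):
    """Hypothetical schedule: interpolate from initial_tau to final_tau."""
    # Progress runs from 0.0 at the first epoch to 1.0 at the last one
    # (assumed; the library may measure progress differently).
    progress = epoch / max(max_epochs - 1, 1)
    progress = min(max(progress, 0.0), 1.0)
    if schedule == "exponential":
        # Geometric interpolation: tau shrinks by a constant factor per epoch.
        return initial_tau * (final_tau / initial_tau) ** progress
    if schedule == "linear":
        # Straight-line interpolation between the two temperatures.
        return initial_tau + (final_tau - initial_tau) * progress
    raise ValueError(f"unknown schedule: {schedule!r}")

exp_mid = annealed_tau(5, 11, schedule="exponential")
lin_mid = annealed_tau(5, 11, schedule="linear")
```

Under these assumptions the exponential schedule reaches small temperatures earlier than the linear one, so operator selection sharpens sooner in training.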
- describe(a: float = 0.5) dict¶
Return interpretability info as dict for a given andness.
- entropy(a: float = 0.5) float¶
Compute entropy of weight distribution for a given andness. Higher entropy = more uniform, lower entropy = more polarized.
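To make the entropy measure concrete, the following sketch computes the Shannon entropy of the softmax weights in plain Python. The names `softmax_weights` and `weight_entropy` are hypothetical, and using `eps` to guard the logarithm is an assumption about how the library might use that parameter.

```python
import math

CENTERS = [1.5, 1.0, 0.5, 0.0, -0.5]  # default centers from the docs

def softmax_weights(a, tau, centers=CENTERS):
    logits = [-((a - c) ** 2) / tau for c in centers]
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def weight_entropy(a, tau=0.5, eps=1e-6):
    """Hypothetical helper: Shannon entropy (in nats) of the operator weights."""
    # eps keeps log() finite when a weight underflows to zero (assumed usage).
    return -sum(p * math.log(p + eps) for p in softmax_weights(a, tau))

at_center = weight_entropy(0.5, tau=0.05)   # a sits on the avg center
between = weight_entropy(0.75, tau=0.05)    # a sits halfway between avg and min
```

In this sketch the entropy is lower when `a` sits on a center than when it lies halfway between two centers, and it approaches `log(5)` as the weights become uniform.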
- entropy_loss(a: Tensor) Tensor¶
Compute differentiable entropy loss for use in training.
Minimizing this loss encourages the tree to learn ‘a’ values that are close to canonical centers, producing concentrated (low-entropy) weight distributions.
- Parameters:
a – Andness parameter tensor from tree (requires_grad=True)
- Returns:
Scalar entropy loss (add to training loss with weight lambda)
Example
loss = task_loss + 0.1 * agg.entropy_loss(a)
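The effect of minimizing the entropy term can be illustrated without an autograd framework. The sketch below is not how the library computes `entropy_loss`. It is a plain-Python stand-in that uses a finite-difference gradient, with hypothetical names (`entropy_of_weights`, `descend_entropy`) and an assumed small temperature of 0.05, to show `a` drifting toward the nearest center.

```python
import math

CENTERS = [1.5, 1.0, 0.5, 0.0, -0.5]  # default centers from the docs

def entropy_of_weights(a, tau):
    """Entropy of softmax(-(a - c_i)^2 / tau) over the five centers."""
    logits = [-((a - c) ** 2) / tau for c in CENTERS]
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    # The 1e-12 guard avoids log(0) for weights that underflow.
    return -sum((e / total) * math.log(e / total + 1e-12) for e in exps)

def descend_entropy(a, tau=0.05, lr=0.01, steps=200, h=1e-5):
    """Plain gradient descent on the entropy, using central differences."""
    for _ in range(steps):
        grad = (entropy_of_weights(a + h, tau)
                - entropy_of_weights(a - h, tau)) / (2 * h)
        a -= lr * grad
    return a

# Starting slightly off the avg center (0.5), a is pulled back onto it.
a_final = descend_entropy(a=0.6)
```

In real training the entropy term is only one part of the loss, so the pull toward a center competes with the task gradient. The weight (0.1 in the example above) sets that balance.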
- forward(x: Tensor, y: Tensor, a: Tensor, w0: Tensor = None, w1: Tensor = None) Tensor¶
Compute mixed aggregation F(x,y) = sum_i w_i * A_i(x,y).
- Parameters:
x – First input tensor, values in [0, 1]
y – Second input tensor, values in [0, 1]
a – Andness parameter from tree (determines operator mixing)
w0 – Weight for x (used for pruning, optional)
w1 – Weight for y (used for pruning, optional)
- Returns:
Aggregated output tensor in [0, 1]
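A scalar sketch of the mixture F(x, y) = sum_i w_i * A_i(x, y) follows. It makes several assumptions: `avg` is taken to be the arithmetic mean, `prob_sum` the probabilistic sum x + y - xy, the pruning weights `w0` and `w1` are ignored, and the function name `mix` is hypothetical.

```python
import math

CENTERS = [1.5, 1.0, 0.5, 0.0, -0.5]  # default centers from the docs

# Operator basis in the same order as CENTERS (definitions assumed).
OPERATORS = [
    lambda x, y: x * y,          # product
    lambda x, y: min(x, y),      # min
    lambda x, y: (x + y) / 2,    # avg (arithmetic mean)
    lambda x, y: max(x, y),      # max
    lambda x, y: x + y - x * y,  # prob_sum (probabilistic sum)
]

def mix(x, y, a, tau=0.5):
    """Hypothetical scalar version of the softmax-weighted mixture."""
    logits = [-((a - c) ** 2) / tau for c in CENTERS]
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return sum((e / total) * op(x, y) for e, op in zip(exps, OPERATORS))

soft = mix(0.2, 0.8, a=1.0)                 # blend dominated by min
near_min = mix(0.2, 0.8, a=1.0, tau=1e-3)   # effectively min(x, y)
near_max = mix(0.2, 0.8, a=0.0, tau=1e-3)   # effectively max(x, y)
```

Because every operator maps [0, 1] inputs to [0, 1] and the weights form a convex combination, the mixed output stays in [0, 1] as well.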
- get_weights_for_andness(a: float) ndarray¶
Return operator weights for a given andness value.
- set_tau(new_tau: float)¶
Update the temperature (clamped to a minimum of 1e-4).
- property tau: float¶
Current temperature.
- class bacon.aggregators.lsp.softmax_lsp.PerNodeLspSoftmaxAggregator(tau: float = 1.0, eps: float = 1e-06)¶
Per-node LSP Softmax Aggregator for use with BACON trees.
Creates a separate set of operator logits for each internal node, similar to OperatorSetAggregator but using the LSP operator basis.
- BACON integration:
Call attach_to_tree(num_layers) after tree is built
Call start_forward() before each forward pass
Call aggregate(left, right, a, w0, w1) for each node
- aggregate(left: Tensor, right: Tensor, a: Tensor, w0: Tensor, w1: Tensor) Tensor¶
Standard BACON aggregator signature. w0, w1 are used for pruning support (bypass inputs when one weight is ~0).
- attach_to_tree(num_layers: int)¶
Called once BACON knows how many internal nodes exist. Creates operator logits for each node.
- describe(node_index: int = None) dict | list¶
Return interpretability info.
- entropy(node_index: int = None) float | list¶
Compute entropy for a specific node (or all nodes).
- get_alpha(node_index: int = None) ndarray¶
Return raw logits for a specific node (or all nodes if None).
- get_weights(node_index: int = None) ndarray¶
Return weights for a specific node (or all nodes if None).
- start_forward()¶
Called at start of each forward pass to reset node pointer.
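The call order described under BACON integration (attach once, reset the pointer, then aggregate node by node) can be sketched with a small stand-in class. Everything here is hypothetical: `PerNodeMixer` is not the library class, it assumes one internal node per layer, it computes weights as softmax(logits / tau), and it ignores the `a`, `w0` and `w1` arguments of the real `aggregate()` signature.

```python
import math

class PerNodeMixer:
    """Hypothetical stand-in showing the per-node pointer pattern."""

    def __init__(self, tau=1.0):
        self.tau = tau
        self.logits = []   # one list of 5 operator logits per node
        self._node = 0     # index of the node served by the next aggregate()

    def attach_to_tree(self, num_layers):
        # Assumption: one internal node per layer, logits start at zero.
        self.logits = [[0.0] * 5 for _ in range(num_layers)]

    def start_forward(self):
        self._node = 0     # reset the pointer before each forward pass

    def get_weights(self, node_index):
        scaled = [z / self.tau for z in self.logits[node_index]]
        m = max(scaled)
        exps = [math.exp(z - m) for z in scaled]
        total = sum(exps)
        return [e / total for e in exps]

    def aggregate(self, left, right):
        w = self.get_weights(self._node)
        self._node += 1    # advance to the next node's logits
        ops = [left * right, min(left, right), (left + right) / 2,
               max(left, right), left + right - left * right]
        return sum(wi * oi for wi, oi in zip(w, ops))

mixer = PerNodeMixer()
mixer.attach_to_tree(2)
mixer.logits[1] = [0.0, 0.0, 0.0, 50.0, 0.0]  # node 1 strongly prefers max
mixer.start_forward()
first = mixer.aggregate(0.2, 0.8)    # node 0: uniform mix of all operators
second = mixer.aggregate(0.2, 0.8)   # node 1: close to max(0.2, 0.8)
```

Forgetting to call `start_forward()` in a design like this would let the pointer run past the last node, which is presumably why the docs list it as a required step before every forward pass.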