BACON Module
- class bacon.FullWeightAggregator(*args, **kwargs)
Bases: AggregatorBase
- aggregate_float(values: Sequence[float], a: float, weights: Sequence[float]) → float
Aggregate N scalar values with scalar andness a and N weights.
- aggregate_tensor(values: Sequence[Any], andness, weights)
Aggregate tensors using the Full Weight method.
- Parameters:
values (Sequence[torch.Tensor]) – Input tensors to aggregate, typically a first and a second tensor.
andness (float) – Andness.
weights (Sequence[float]) – Per-input weights, one for each tensor in values.
- Returns:
Resulting tensor after aggregation.
- Return type:
torch.Tensor
- class bacon.HalfWeightAggregator(*args, **kwargs)
Bases: AggregatorBase
- aggregate_float(values: Sequence[float], a: float, weights: Sequence[float]) → float
Aggregate N scalar values with scalar andness a and N weights.
- aggregate_tensor(values: Sequence[Any], andness, weights) → Any
Aggregate N tensors with andness a and N weights.
- r(a)
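Both weighted aggregators take N values, an andness a, and N weights, and HalfWeightAggregator exposes an undocumented r(a). In LSP-style graded logic, andness is often realized through a weighted power mean whose exponent r is chosen to match a target andness. The sketch below is a hypothetical illustration of that idea, not this library's implementation: power_mean, global_andness, and r_from_andness are invented names. The andness is the mean-value global andness (orness = ((n + 1)·mean(M) - 1)/(n - 1), andness = 1 - orness) for two equally weighted inputs, with the integral approximated on a midpoint grid.

```python
import math
from typing import Sequence


def power_mean(values: Sequence[float], weights: Sequence[float], r: float) -> float:
    """Weighted power mean (sum_i w_i * x_i**r) ** (1/r) for inputs in (0, 1]."""
    if abs(r) < 1e-12:
        # The limit r -> 0 is the weighted geometric mean.
        return math.exp(sum(w * math.log(x) for x, w in zip(values, weights)))
    return sum(w * x ** r for x, w in zip(values, weights)) ** (1.0 / r)


def global_andness(r: float, grid: int = 64) -> float:
    """Andness of the equal-weight two-input power mean with exponent r.

    mean(M) is the average of M over the unit square (midpoint rule);
    orness = ((n + 1) * mean(M) - 1) / (n - 1) and andness = 1 - orness.
    """
    n = 2
    pts = [(i + 0.5) / grid for i in range(grid)]
    total = sum(power_mean((x, y), (0.5, 0.5), r) for x in pts for y in pts)
    mean_m = total / grid ** 2
    orness = ((n + 1) * mean_m - 1.0) / (n - 1)
    return 1.0 - orness


def r_from_andness(a: float, lo: float = -50.0, hi: float = 50.0, iters: int = 40) -> float:
    """Invert global_andness by bisection (andness decreases as r grows)."""
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if global_andness(mid) > a:
            lo = mid  # still too AND-like: move toward larger r (more OR)
        else:
            hi = mid
    return 0.5 * (lo + hi)
```

Under these assumptions r = 1 (arithmetic mean) has andness 0.5 and r = 0 (geometric mean) has andness close to 2/3; more negative exponents move toward min (andness 1) and more positive ones toward max (andness 0).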
- class bacon.MinMaxAggregator(*args, **kwargs)
Bases: AggregatorBase
Aggregate a sequence of inputs with min/max interpolation.
The aggregator computes the elementwise minimum when andness is near 1 and the elementwise maximum when andness is near 0, using a straight-through hard gating step so that gradients can still flow through the continuous proxy during training.
- aggregate_float(values: Sequence[float], a: float, weights: Sequence[float]) → float
Float-friendly wrapper around aggregate_tensor().
- Parameters:
values – Input scalar values to aggregate.
a – Andness control value.
weights – Optional per-input weights used as soft gates.
- Returns:
Python float result.
- aggregate_tensor(values: Sequence[Any], andness, weights=None)
Aggregate one or more tensor-like values.
- Parameters:
values – Non-empty sequence of tensors with compatible shapes.
andness – Control signal where larger values bias toward min/AND and smaller values bias toward max/OR.
weights – Optional per-input gates in [0, 1]. When provided, each input is blended with a neutral baseline before reduction.
- Returns:
Tensor with the same broadcast-compatible shape as each input.
- Raises:
ValueError – If values is empty.
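The min/max interpolation and gating described for MinMaxAggregator can be sketched on plain floats. This is an assumption-laden illustration, not the class itself: the function name min_max_aggregate, the 0.5 cutoff for the hard gate, the linear proxy, and the neutral baselines (1.0 for the min branch, 0.0 for the max branch) are all hypothetical choices.

```python
from typing import Optional, Sequence, Tuple


def min_max_aggregate(values: Sequence[float], andness: float,
                      weights: Optional[Sequence[float]] = None) -> Tuple[float, float]:
    """Return (hard forward value, continuous proxy) for min/max interpolation."""
    if not values:
        raise ValueError("values must be non-empty")
    if weights is None:
        weights = [1.0] * len(values)
    # Assumed neutral baselines: 1.0 cannot lower a minimum and 0.0 cannot
    # raise a maximum on [0, 1], so a gate of 0 effectively removes an input.
    and_inputs = [w * x + (1.0 - w) * 1.0 for x, w in zip(values, weights)]
    or_inputs = [w * x + (1.0 - w) * 0.0 for x, w in zip(values, weights)]
    lo, hi = min(and_inputs), max(or_inputs)
    # Continuous proxy: linear interpolation between OR (a = 0) and AND (a = 1).
    soft = andness * lo + (1.0 - andness) * hi
    # Hard gate used for the forward value.
    hard = lo if andness >= 0.5 else hi
    # With autograd, the straight-through form soft + (hard - soft).detach()
    # has the value of `hard` while gradients follow `soft`.
    return hard, soft
```

For example, min_max_aggregate([0.2, 0.8], 0.9) yields 0.2 as its forward value, and setting the first gate to 0.0 makes the same call yield 0.8.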
baconNet
- class bacon.baconNet.TrainingSetup(optimizer: Optimizer, criterion: Module, pos_weight: float | None, param_groups: list, aggregation_frozen: bool, actual_max_epochs: int, use_temperature_annealing: bool, perm_temp_decay_rate: float | None, trans_temp_decay_rate: float | None, operator_tau_decay_rate: float | None, operator_initial_tau: float | None, anneal_over_epochs: int, original_sparsity_weight: float | None, task_type: str, loss_history: list, accuracy_history: list, freeze_confidence_warning_shown: bool, best_loss_for_convergence: float, epochs_without_improvement: int, epoch_when_frozen: int | None, has_converged_before_freeze: bool, best_confidence: float, epochs_without_confidence_improvement: int, best_loss: float, epochs_since_improvement: int, temp_paused: bool, transformation_converged: bool, best_frozen_loss: float | None, best_frozen_state: dict | None, frozen_lr_reduced: bool, best_overall_loss: float | None, best_overall_state: dict | None, best_overall_epoch: int | None)
Bases: object
Encapsulates all training setup state for a single attempt.
- accuracy_history: list
- actual_max_epochs: int
- aggregation_frozen: bool
- anneal_over_epochs: int
- best_confidence: float
- best_frozen_loss: float | None
- best_frozen_state: dict | None
- best_loss: float
- best_loss_for_convergence: float
- best_overall_epoch: int | None
- best_overall_loss: float | None
- best_overall_state: dict | None
- criterion: Module
- epoch_when_frozen: int | None
- epochs_since_improvement: int
- epochs_without_confidence_improvement: int
- epochs_without_improvement: int
- freeze_confidence_warning_shown: bool
- frozen_lr_reduced: bool
- has_converged_before_freeze: bool
- loss_history: list
- operator_initial_tau: float | None
- operator_tau_decay_rate: float | None
- optimizer: Optimizer
- original_sparsity_weight: float | None
- param_groups: list
- perm_temp_decay_rate: float | None
- pos_weight: float | None
- task_type: str
- temp_paused: bool
- trans_temp_decay_rate: float | None
- transformation_converged: bool
- use_temperature_annealing: bool
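The TrainingSetup fields best_loss_for_convergence and epochs_without_improvement look like patience-based convergence bookkeeping, driven by the convergence_patience and convergence_delta arguments of find_best_model. A minimal sketch of that bookkeeping, assuming a loss only counts as an improvement when it beats the best loss by more than delta; ConvergenceTracker is a hypothetical helper, not part of the package.

```python
import math
from dataclasses import dataclass


@dataclass
class ConvergenceTracker:
    """Hypothetical helper mirroring two TrainingSetup fields."""
    patience: int = 500   # cf. convergence_patience
    delta: float = 1e-3   # cf. convergence_delta
    best_loss_for_convergence: float = math.inf
    epochs_without_improvement: int = 0

    def update(self, loss: float) -> bool:
        """Record one epoch's loss; return True once training counts as converged."""
        if loss < self.best_loss_for_convergence - self.delta:
            # Meaningful improvement: remember it and reset the patience counter.
            self.best_loss_for_convergence = loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience
```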
- class bacon.baconNet.baconNet(input_size, tree_layout='left', loss_amplifier=1, weight_penalty_strength=0.001, weight_mode='trainable', weight_normalization='minmax', aggregator='lsp.full_weight', normalize_andness=True, is_frozen=False, use_transformation_layer=False, transformation_temperature=None, transformation_use_gumbel=False, transformations=None, early_stop_threshold_large_inputs=0.1, permutation_initial_temperature=5.0, permutation_final_temperature=0.1, transformation_initial_temperature=1.0, transformation_final_temperature=0.1, loss_weight_main=1.0, loss_weight_perm_entropy=0.0, loss_weight_trans_entropy=0.0, loss_weight_perm_sparsity=0.01, loss_weight_operator_sparsity=1.0, loss_weight_operator_l2=0.0, lr_permutation=0.3, lr_transformation=0.5, lr_aggregator=0.1, lr_other=0.1, use_class_weighting=True, loss_trim_percentile: float = 0.0, loss_trim_mode: str = 'drop_high', loss_trim_start_epoch: int = 0, training_policy=None, full_tree_depth: int = None, full_tree_shape: str = 'triangle', full_tree_temperature: float = 3.0, full_tree_final_temperature: float = 0.1, full_tree_max_egress: int = None, full_tree_concentrate_ingress: bool = False, full_tree_use_sinkhorn: bool = False, loss_weight_full_tree_egress: float = 0.0, loss_weight_full_tree_ingress: float = 0.5, loss_weight_full_tree_ingress_balance: float = 0.0, loss_weight_full_tree_scale_reg: float = 0.0, alternating_learn_first_routing: bool = True, alternating_learn_subsequent_routing: bool = True, alternating_learn_exponents: bool = False, alternating_min_exponent: float = 1.0, alternating_max_exponent: float = 2.0, alternating_max_egress: int = 1, alternating_use_straight_through: bool = True, loss_weight_alternating_balance: float = 50.0, loss_weight_alternating_egress: float = 0.5, loss_weight_alternating_exponent_reg: float = 0.0, use_constant_input: bool = False, use_permutation_layer: bool = True, regression_loss_type: str = 'mse')
Bases: Module
Represents a BACON network for interpretable decision-making using graded logic.
- Parameters:
input_size (int) – Number of input features. This is likely to be removed in the future.
tree_layout (str, optional) – Layout of the tree. Defaults to “left”. The full_tree_* parameters below apply only when tree_layout="full".
loss_amplifier (float, optional) – Amplifier for the loss. Defaults to 1.
weight_penalty_strength (float, optional) – Penalty strength on weights. Defaults to 1e-3. A stronger penalty leads to more balanced weights (closer to 0.5).
normalize_andness (bool, optional) – Whether to normalize andness. Defaults to True. This should be set to False if the chosen aggregator, such as bool.min_max, already normalizes the andness.
weight_mode (str, optional) – Mode for weight configuration. Defaults to “trainable”. Use “fixed” for fixed weights (set to 0.5).
aggregator (str, optional) – Aggregator to be used. Defaults to “lsp.full_weight”.
is_frozen (bool, optional) – Whether to freeze the structure. Defaults to False.
early_stop_threshold_large_inputs (float, optional) – Early stop threshold for transformation layers with 10+ inputs. Defaults to 0.1. Lower values require more training but achieve higher accuracy.
transformations (list, optional) – List of transformation objects to use. If None, uses all 6 default transformations. Example: [IdentityTransformation(n), NegationTransformation(n)] for identity+negation only.
permutation_initial_temperature (float, optional) – Starting temperature for permutation annealing. Defaults to 5.0. Higher = more initial exploration.
permutation_final_temperature (float, optional) – Final temperature for permutation annealing. Defaults to 0.1. Lower = harder final permutation.
transformation_initial_temperature (float, optional) – Starting temperature for the transformation layer. Defaults to 1.0. Should be lower than the permutation temperature, since the transformation search space is smaller (2^n vs. n! states).
transformation_final_temperature (float, optional) – Final temperature for the transformation layer. Defaults to 0.1, the same as the permutation final temperature.
loss_weight_main (float, optional) – Weight for main BCE loss. Defaults to 1.0.
loss_weight_perm_entropy (float, optional) – Weight for permutation entropy regularization. Defaults to 0.0. Higher = encourage exploration. Typical range: 0.0-0.1.
loss_weight_trans_entropy (float, optional) – Weight for transformation entropy regularization. Defaults to 0.0. Higher = encourage decisive transformation selection. Typical range: 0.0-0.1.
loss_weight_perm_sparsity (float, optional) – Weight for permutation sparsity loss. Defaults to 0.01. Penalizes high entropy (flat distributions) to encourage peaked/sparse permutations. Higher = stronger push toward confident or clear multi-modal distributions. Typical range: 0.0-0.1.
loss_weight_operator_sparsity (float, optional) – Weight for operator selection sparsity loss. Defaults to 1.0. Penalizes uncertain operator choices to encourage commitment. Higher = faster operator decision.
loss_weight_operator_l2 (float, optional) – Weight for L2 regularization on operator logits. Defaults to 0.0. Keeps logits bounded so tau can control commitment timing. Use > 0 when operators commit too early.
lr_permutation (float, optional) – Learning rate for permutation layer. Defaults to 0.3. Higher = faster exploration of feature orderings.
lr_transformation (float, optional) – Learning rate for transformation layer. Defaults to 0.5. Higher = faster transformation selection.
lr_aggregator (float, optional) – Learning rate for aggregator weights. Defaults to 0.1. Lower = more stable tree structure.
lr_other (float, optional) – Learning rate for other parameters. Defaults to 0.1.
use_class_weighting (bool, optional) – Whether to apply class weighting for imbalanced data. Defaults to True. When True, penalizes minority class errors more heavily (pos_weight = neg_count/pos_count). When False, uses standard BCE loss (original behavior).
full_tree_depth (int, optional) – Depth of the fully connected tree. Only used when tree_layout="full". Defaults to None (uses input_size - 1).
full_tree_shape (str, optional) – Shape of the fully connected tree. “triangle” (default) or “square”.
full_tree_temperature (float, optional) – Initial temperature for sigmoid edge weights. Defaults to 3.0.
full_tree_final_temperature (float, optional) – Final temperature after annealing. Defaults to 0.1.
full_tree_max_egress (int, optional) – If set, a loss term encourages each source node to concentrate its outgoing edges on its top-K destinations, where K is this value. Defaults to None (no constraint).
loss_weight_full_tree_egress (float, optional) – Weight for full tree egress constraint loss. Defaults to 0.0.
loss_weight_full_tree_ingress (float, optional) – Weight for full tree ingress constraint loss (max 2 inputs per node). Defaults to 0.5.
use_permutation_layer (bool, optional) – Whether to use the permutation layer. Defaults to True. Set to False for full tree layout to let the tree learn input routing directly.
regression_loss_type (str, optional) – Loss type for regression mode. Defaults to “mse”. Supported values are “mse” (standard MSE, scale-sensitive), “correlation” (Pearson correlation loss, scale-invariant), and “normalized_mse” (z-score normalized MSE for scale-invariant pattern matching).
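With use_class_weighting=True, pos_weight is set to neg_count/pos_count. The sketch below assumes this weight multiplies the positive-class term of binary cross-entropy, which is how PyTorch's torch.nn.BCEWithLogitsLoss applies its pos_weight argument. It is written on probabilities in plain Python, and pos_weight_from_labels and weighted_bce are hypothetical names.

```python
import math
from typing import Sequence


def pos_weight_from_labels(labels: Sequence[int]) -> float:
    """pos_weight = neg_count / pos_count for binary labels in {0, 1}."""
    pos = sum(1 for y in labels if y == 1)
    neg = len(labels) - pos
    if pos == 0:
        raise ValueError("need at least one positive label")
    return neg / pos


def weighted_bce(probs: Sequence[float], labels: Sequence[int], pos_weight: float = 1.0) -> float:
    """Mean binary cross-entropy with the positive-class term scaled by pos_weight."""
    eps = 1e-12  # keep log() away from 0
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)
        total += -(pos_weight * y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(labels)
```

With one positive among four samples, pos_weight is 3.0, so each error on the positive sample costs three times as much as an equally confident error on a negative one.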
- evaluate(x, y, threshold=0.5)
- find_best_model(x, y, x_test, y_test, attempts=100, acceptance_threshold=0.95, save_path='./assembler.pth', max_epochs=12000, annealing_epochs=None, frozen_training_epochs=200, convergence_patience=500, convergence_delta=0.001, freeze_confidence_threshold=0.95, freeze_min_confidence=0.85, loss_weight_perm_sparsity=None, sparsity_schedule=None, freeze_aggregation_epochs=0, save_model=True, use_hierarchical_permutation=False, force_freeze=True, hierarchical_group_size=3, hierarchical_epochs_per_attempt=None, hierarchical_bleed_ratio=0.1, hierarchical_bleed_decay=2.0, sinkhorn_iters=100, binary_threshold=0.5, task_type='classification', operator_initial_tau=5.0, operator_final_tau=0.5, operator_freeze_min_confidence=0.7, operator_freeze_epochs=0, skip_frozen_threshold=0.99, full_tree_egress_warmup_epochs=0, full_tree_egress_ramp_epochs=0, full_tree_egress_start_metric=0.99, full_tree_egress_drop_tolerance=0.02, full_tree_egress_adapt_rate=0.2)
Find the best model by training multiple times and evaluating the metric of each attempt (accuracy for classification, R² for regression).
- Parameters:
x (torch.Tensor) – Input tensor for training.
y (torch.Tensor) – Target tensor for training.
x_test (torch.Tensor) – Input tensor for testing.
y_test (torch.Tensor) – Target tensor for testing.
attempts (int, optional) – Number of attempts to find the best model. Defaults to 100.
acceptance_threshold (float, optional) – Minimum accuracy to accept a model. Defaults to 0.95.
save_path (str, optional) – Path to save the best model. Defaults to “./assembler.pth”.
max_epochs (int, optional) – Maximum epochs for training (safety limit). Defaults to 12000.
annealing_epochs (int, optional) – Epochs for temperature annealing. Defaults to None.
frozen_training_epochs (int, optional) – Epochs to train after freezing. Defaults to 200.
convergence_patience (int, optional) – Epochs without improvement before considering converged. Defaults to 500.
convergence_delta (float, optional) – Minimum loss improvement to reset patience. Defaults to 0.001.
freeze_confidence_threshold (float, optional) – Mean max probability threshold for high-confidence freezing. Defaults to 0.95.
freeze_min_confidence (float, optional) – Minimum confidence for early freeze when combined with low loss. Defaults to 0.85 (raised from 0.75 to be more conservative).
sinkhorn_iters (int, optional) – Number of Sinkhorn normalization iterations for soft permutation convergence. Higher values improve doubly-stochastic property but increase compute time. Defaults to 100.
loss_weight_perm_sparsity (float, optional) – Weight for permutation sparsity loss (encourages peaked distributions). If None, uses instance default. Defaults to None.
sparsity_schedule (tuple, optional) – Dynamic sparsity weight scheduling as (initial_weight, final_weight, transition_epochs). Example: (10.0, 0.1, 1000) starts with high sparsity emphasis (10.0) and linearly decreases to 0.1 over 1000 epochs. Defaults to None (uses constant loss_weight_perm_sparsity).
freeze_aggregation_epochs (int, optional) – Freeze aggregation parameters for first N epochs, allowing only permutation to learn. Useful for giving permutation undiluted classification signal. Defaults to 0 (no freezing).
save_model (bool, optional) – Whether to save the best model. Defaults to True.
use_hierarchical_permutation (bool, optional) – Use coarse-grained permutation exploration. Defaults to False.
hierarchical_group_size (int, optional) – Group size for hierarchical permutation (e.g., 3 for 10 inputs → 4x4 coarse matrix). Defaults to 3.
hierarchical_epochs_per_attempt (int, optional) – Epochs to run for each coarse permutation. If None, uses max_epochs. Defaults to None.
hierarchical_bleed_ratio (float, optional) – Ratio of std for adjacent blocks (0.0=hard blocks, 0.1=10% bleed, 1.0=full bleed). Defaults to 0.1.
hierarchical_bleed_decay (float, optional) – How quickly bleeding decays with distance (higher=faster decay). Defaults to 2.0.
operator_freeze_min_confidence (float, optional) – Minimum average operator selection confidence required before freezing permutation. Blocks freeze until operators commit to their choices. 0.0 disables the requirement (legacy behavior), 0.7 requires 70% average operator confidence. Defaults to 0.7.
operator_freeze_epochs (int, optional) – Freeze operator selection for first N epochs, allowing only edge/routing to learn. This decouples structure discovery from operator selection. After N epochs, operators unfreeze and start learning. Defaults to 0 (no operator freeze).
skip_frozen_threshold (float, optional) – Minimum metric required to skip frozen training after freezing. Frozen training is skipped only if freezing improves the metric and the metric is above this threshold. For regression tasks, this should be very high (e.g., 0.99) to ensure the weights fully converge. Defaults to 0.99.
full_tree_egress_warmup_epochs (int, optional) – In full-tree mode, keep egress concentration disabled for the first N epochs so structure can be learned first. Defaults to 0.
full_tree_egress_ramp_epochs (int, optional) – Number of epochs to ramp egress loss weight from 0 to target after warmup. Defaults to 0 (immediate application after warmup).
full_tree_egress_start_metric (float, optional) – Minimum training metric required before egress concentration begins. This lets the full tree learn the task first before sparsifying edges. Defaults to 0.99.
full_tree_egress_drop_tolerance (float, optional) – Maximum allowed training-metric drop from warmup baseline before reducing egress pressure. Defaults to 0.02.
full_tree_egress_adapt_rate (float, optional) – Fractional backoff/adjustment rate for dynamic egress weight when metric drop exceeds tolerance. Defaults to 0.2.
- Returns:
Best model state dictionary and its metric (accuracy or R²).
- Return type:
tuple
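The sparsity_schedule tuple and the full-tree egress warmup and ramp parameters describe simple epoch-based schedules. A minimal sketch, assuming linear interpolation in both cases. sparsity_weight and egress_weight are hypothetical helpers, the 0.01 fallback mirrors the loss_weight_perm_sparsity default of baconNet, and the full_tree_egress_start_metric gate and the adaptive backoff are deliberately left out.

```python
from typing import Optional, Tuple


def sparsity_weight(epoch: int, schedule: Optional[Tuple[float, float, int]],
                    constant: float = 0.01) -> float:
    """Linear (initial_weight, final_weight, transition_epochs) schedule; constant if None."""
    if schedule is None:
        return constant
    initial, final, transition_epochs = schedule
    if transition_epochs <= 0 or epoch >= transition_epochs:
        return final
    frac = epoch / transition_epochs
    return initial + (final - initial) * frac


def egress_weight(epoch: int, target: float, warmup_epochs: int = 0, ramp_epochs: int = 0) -> float:
    """Zero during warmup, then a linear ramp to target (immediate if ramp_epochs == 0)."""
    if epoch < warmup_epochs:
        return 0.0
    if ramp_epochs <= 0:
        return target
    return target * min(1.0, (epoch - warmup_epochs) / ramp_epochs)
```

With the documented example schedule (10.0, 0.1, 1000), the weight starts at 10.0, is about 5.05 at epoch 500, and stays at 0.1 from epoch 1000 onward.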
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- inference(x, threshold=0.5)
- inference_raw(x)
- load_model(filepath)
- make_param_groups()
- prune_features(features)
- save_model(filepath)
- train_model(x, y, epochs)
binaryTreeLogicNet
- class bacon.binaryTreeLogicNet.binaryTreeLogicNet(input_size, weight_mode='trainable', weight_normalization='minmax', weight_value=0.5, weight_range=(0.0, 1.0), weight_choices=None, noise_increase=1.05, noise_decrease=0.95, loss_amplifier=1.0, normalize_andness=True, min_noise=0.0, max_noise=2.0, is_frozen=False, tree_layout='left', weight_penalty_strength=0.001, aggregator='lsp.full_weight', early_stop_patience=10, early_stop_min_delta=0.0001, early_stop_threshold=0.01, use_transformation_layer: bool = False, transformation_temperature: float = 1.0, transformation_use_gumbel: bool = False, transformations=None, device=None, sinkhorn_iters=100, full_tree_depth: int = None, full_tree_shape: str = 'triangle', full_tree_temperature: float = 3.0, full_tree_final_temperature: float = 0.1, full_tree_max_egress: int = None, full_tree_concentrate_ingress: bool = False, full_tree_use_sinkhorn: bool = False, alternating_learn_first_routing: bool = True, alternating_learn_subsequent_routing: bool = True, alternating_learn_exponents: bool = False, alternating_min_exponent: float = 1.0, alternating_max_exponent: float = 2.0, alternating_max_egress: int = 1, alternating_use_straight_through: bool = True, alternating_balance_weight: float = 50.0, alternating_egress_weight: float = 0.5, use_constant_input: bool = False, use_permutation_layer: bool = True)
Bases: Module
Represents a binary tree logic network for interpretable decision-making using graded logic.
- Parameters:
input_size (int) – Number of input features.
weight_mode (str, optional) – Mode for weight configuration. Defaults to “trainable”.
weight_value (float, optional) – Initial value for fixed weights. Defaults to 0.5.
weight_range (tuple, optional) – Range for random weights. Defaults to (0.0, 1.0).
weight_choices (list, optional) – Choices for discrete weights. Defaults to None.
noise_increase (float, optional) – Factor to increase noise. Defaults to 1.05.
noise_decrease (float, optional) – Factor to decrease noise. Defaults to 0.95.
loss_amplifier (float, optional) – Amplifier for the loss. Defaults to 1.0.
min_noise (float, optional) – Minimum noise level. Defaults to 0.0.
max_noise (float, optional) – Maximum noise level. Defaults to 2.0.
is_frozen (bool, optional) – Whether to freeze the structure. Defaults to False.
tree_layout (str, optional) – Layout of the tree. Defaults to “left”. Options: “left”, “balanced”, “paired”, “full”, “alternating”.
weight_penalty_strength (float, optional) – Penalty strength on weights. Defaults to 1e-3. A stronger penalty leads to more balanced weights (closer to 0.5).
aggregator (str or callable, optional) – Aggregator to be used. Defaults to “lsp.full_weight”.
device (torch.device, optional) – Device to run the model on. Defaults to None (uses CUDA if available).
full_tree_depth (int, optional) – Depth of the fully connected tree. Only used when tree_layout="full". Defaults to None (uses input_size - 1).
full_tree_shape (str, optional) – Shape of the fully connected tree. “triangle” (default) or “square”.
full_tree_temperature (float, optional) – Initial temperature for sigmoid edge weights. Defaults to 3.0.
full_tree_final_temperature (float, optional) – Final temperature after annealing. Defaults to 0.1.
full_tree_max_egress (int, optional) – Each source concentrates on top-K destinations (via loss). Defaults to None (no constraint).
use_permutation_layer (bool, optional) – Whether to use the permutation layer. Defaults to True. Set to False for full tree layout to let the tree learn input routing directly.
- anneal_alternating_tree_gumbel(progress: float, initial: float = 1.0, final: float = 0.0) → None
Anneal the Gumbel noise scale of the alternating tree.
- Parameters:
progress – Training progress from 0.0 to 1.0
initial – Initial noise scale
final – Final noise scale
- anneal_alternating_tree_temperature(progress: float) → None
Anneal the temperature of the alternating tree.
- Parameters:
progress – Training progress from 0.0 to 1.0
- anneal_full_tree_gumbel(progress: float, initial: float = 1.0, final: float = 0.0) → None
Anneal the Gumbel noise scale of the fully connected tree.
- Parameters:
progress – Training progress from 0.0 to 1.0
initial – Initial noise scale
final – Final noise scale
- anneal_full_tree_temperature(progress: float) → None
Anneal the temperature of the fully connected tree.
- Parameters:
progress – Training progress from 0.0 to 1.0
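The annealing methods take a progress value from 0.0 to 1.0, but how they interpolate is not documented here. The sketch below assumes geometric interpolation for temperatures (for example, full_tree_temperature 3.0 down to full_tree_final_temperature 0.1). It assumes linear interpolation for Gumbel noise scales, whose default final value of 0.0 rules out a geometric schedule. Both function names are hypothetical.

```python
def anneal_geometric(progress: float, initial: float, final: float) -> float:
    """Geometric interpolation from initial to final; both must be > 0."""
    p = min(max(progress, 0.0), 1.0)  # clamp progress to [0, 1]
    return initial * (final / initial) ** p


def anneal_linear(progress: float, initial: float = 1.0, final: float = 0.0) -> float:
    """Linear interpolation; also works when final is 0 (e.g. a Gumbel noise scale)."""
    p = min(max(progress, 0.0), 1.0)
    return initial + (final - initial) * p
```

A geometric schedule spends proportionally more of training at low temperatures, which is a common choice when the final value is much smaller than the initial one.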
- build_balanced_tree(node_outputs, weights, biases)
Build a balanced binary tree from node outputs.
- Parameters:
node_outputs (list) – List of node outputs.
weights (list) – List of weights for the nodes.
biases (list) – List of biases for the nodes.
- Returns:
Final output of the balanced tree.
- Return type:
torch.Tensor
- build_paired_tree(node_outputs, weights, biases)
Build a two-stage ‘paired’ tree: first pair inputs (0,1), (2,3), … then fold pairs.
- Parameters:
node_outputs (list) – Leaf outputs.
weights (list) – Aggregator weights per node.
biases (list) – Aggregator biases per node.
- Returns:
Final aggregated output.
- Return type:
torch.Tensor
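The fixed layouts differ mainly in the order in which leaves are combined. The sketch below shows that order with a generic binary operation, assuming "balanced" means recursive halving and that an odd leftover leaf in the paired layout is carried over unpaired. left_tree, balanced_tree, and paired_tree are hypothetical names, and the weights and biases taken by the real methods are omitted.

```python
from typing import Any, Callable, List, Sequence

Combine = Callable[[Any, Any], Any]


def left_tree(leaves: Sequence[Any], op: Combine) -> Any:
    """Left-associative fold: op(op(op(x0, x1), x2), x3), ..."""
    out = leaves[0]
    for leaf in leaves[1:]:
        out = op(out, leaf)
    return out


def balanced_tree(leaves: Sequence[Any], op: Combine) -> Any:
    """Recursively split the leaves in half and combine the two halves."""
    if len(leaves) == 1:
        return leaves[0]
    mid = len(leaves) // 2
    return op(balanced_tree(leaves[:mid], op), balanced_tree(leaves[mid:], op))


def paired_tree(leaves: Sequence[Any], op: Combine) -> Any:
    """Pair (0, 1), (2, 3), ... then fold the pairs left to right."""
    pairs: List[Any] = [op(leaves[i], leaves[i + 1]) for i in range(0, len(leaves) - 1, 2)]
    if len(leaves) % 2:  # odd count: assumed to carry the last leaf over unpaired
        pairs.append(leaves[-1])
    return left_tree(pairs, op)
```

Passing op=lambda a, b: (a, b) returns the nesting as tuples, which makes the shape easy to inspect; passing a real aggregator would compute the output instead.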
- forward(x, targets=None)
Forward pass through the binary tree logic network.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, input_size).
targets (torch.Tensor|None) – Optional target tensor [batch,1] in {0,1}; if provided and auto_refine is enabled during training, the model may run a light-weight permutation search and freeze the best permutation.
- Returns:
Output tensor of shape (batch_size, 1).
- Return type:
torch.Tensor
- get_alternating_tree_balance_loss() → Tensor
Get balance loss for alternating tree routing.
Encourages balanced distribution of inputs across destinations. Returns 0 if not using “alternating” layout.
- get_alternating_tree_egress_loss() → Tensor
Get egress loss for alternating tree routing.
Encourages peaked row distributions (clear winner per source). Returns 0 if not using “alternating” layout.
- get_alternating_tree_exponent_regularization_loss() → Tensor
Get exponent regularization loss for alternating coefficient layers.
- get_alternating_tree_num_nodes() → int
Get number of aggregation nodes in alternating tree.
Returns 0 if not using “alternating” layout.
- get_alternating_tree_structure_description() → str
Get human-readable structure description for alternating tree.
Returns empty string if not using “alternating” layout.
- get_full_tree_confidence() → float
Get confidence score for fully connected tree edge selections.
Higher values indicate more peaked edge weight distributions. Returns 0 if not using “full” layout.
- get_full_tree_egress_loss() → Tensor
Get egress constraint loss for fully connected tree.
Encourages each source node to concentrate outgoing edges to top-K destinations, where K is controlled by full_tree_max_egress parameter. Returns 0 if not using “full” layout or max_egress is None.
- get_full_tree_ingress_balance_loss() → Tensor
Get ingress balance loss for fully connected tree.
Encourages balanced distribution of inputs across destinations. Prevents all sources from routing to the same destination. Returns 0 if not using “full” layout.
- get_full_tree_ingress_loss() → Tensor
Get ingress constraint loss for fully connected tree.
Discourages each destination node from receiving more than 2 inputs. Binary operators (add, mul) work best with exactly 2 inputs. Returns 0 if not using “full” layout.
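The full-tree egress and ingress constraints can be pictured on an edge-weight matrix W[src][dst]. The penalties below are hypothetical forms chosen for illustration, not the formulas the package uses: a squared hinge on incoming edge mass above two per destination, and the share of each source's outgoing mass that falls outside its top-k destinations.

```python
from typing import List

Matrix = List[List[float]]  # W[src][dst], nonnegative edge weights


def ingress_penalty(w: Matrix, max_inputs: float = 2.0) -> float:
    """Squared hinge on incoming edge mass above max_inputs at each destination."""
    n_dst = len(w[0])
    total = 0.0
    for d in range(n_dst):
        incoming = sum(row[d] for row in w)
        total += max(0.0, incoming - max_inputs) ** 2
    return total


def egress_penalty(w: Matrix, k: int = 1) -> float:
    """Sum over sources of the fraction of outgoing mass outside the top-k destinations."""
    total = 0.0
    for row in w:
        mass = sum(row)
        if mass <= 0.0:
            continue  # a source with no outgoing mass contributes nothing
        top_k = sum(sorted(row, reverse=True)[:k])
        total += 1.0 - top_k / mass
    return total
```

Three sources that all route fully to destination 0 incur an ingress penalty of 1.0 (three inputs, one over the limit) but no egress penalty, since each source already commits to a single destination.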
- get_full_tree_scale_regularization_loss() → Tensor
Get scale regularization loss for fully connected tree.
Penalizes extreme scale coefficients to prevent coefficient hacking. Returns 0 if not using “full” layout.
- get_full_tree_structure() → dict
Get the learned structure of the fully connected tree.
Returns dictionary with layer widths, significant edges, and biases. Returns empty dict if not using “full” layout.
- get_input_labels(variable_names=None)
- harden_alternating_tree() → None
Harden the alternating tree to discrete edge selections.
- harden_full_tree(mode: str = 'argmax') → None
Harden the fully connected tree to discrete edge selections.
- Parameters:
mode – Hardening mode:
- “argmax”: destination-wise argmax by default; if max_egress == 1, uses row-wise argmax.
- “auto”: row-wise argmax when max_egress == 1, otherwise “smart”.
- Other modes are forwarded to FullyConnectedTree.harden(…).
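The difference between the row-wise and destination-wise argmax modes is easiest to see on a small matrix. This sketch assumes rows are sources and columns are destinations, consistent with the "clear winner per source" wording used for row distributions; harden_rowwise and harden_destinationwise are hypothetical names.

```python
from typing import List

Matrix = List[List[float]]  # W[src][dst]


def harden_rowwise(w: Matrix) -> List[List[int]]:
    """Each source (row) keeps only its strongest destination."""
    out = []
    for row in w:
        best = max(range(len(row)), key=row.__getitem__)
        out.append([1 if j == best else 0 for j in range(len(row))])
    return out


def harden_destinationwise(w: Matrix) -> List[List[int]]:
    """Each destination (column) keeps only its strongest source."""
    n_src, n_dst = len(w), len(w[0])
    out = [[0] * n_dst for _ in range(n_src)]
    for d in range(n_dst):
        best = max(range(n_src), key=lambda s: w[s][d])
        out[best][d] = 1
    return out
```

With W = [[0.9, 0.8], [0.1, 0.7]], row-wise hardening gives each source exactly one destination, while destination-wise hardening lets source 0 feed both destinations and leaves source 1 unused.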
- load_model(file_name)
Load the model state from a file.
- Parameters:
file_name (str) – Path to load the model from.
- Raises:
ValueError – If there’s an architecture mismatch that would cause random reinitialization.
- prune_features(feature_index)
Prune a single feature by adjusting its corresponding aggregator weight.
In a left-associative tree:
- Feature 0 is the left input of aggregator 0.
- Feature 1 is the right input of aggregator 0.
- Feature 2 is the right input of aggregator 1, so it uses the aggregator at index 1.
- In general, feature k (k >= 2) uses the aggregator at index k-1.
This method is designed to be called incrementally by external code to build up cumulative pruning. It does NOT clear the existing pruning state; each call adds to it.
- Parameters:
feature_index (int) – Index of the feature to prune (0 to num_leaves-1).
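The feature-to-aggregator mapping for the left-associative tree reduces to max(k - 1, 0). A small sketch; aggregator_index_for_feature is a hypothetical helper, not a method of the class.

```python
def aggregator_index_for_feature(feature_index: int, num_leaves: int) -> int:
    """Index of the aggregator that receives a feature in a left-associative tree."""
    if not 0 <= feature_index < num_leaves:
        raise IndexError("feature_index out of range")
    # Features 0 and 1 both enter aggregator 0; feature k >= 2 enters aggregator k - 1.
    return max(feature_index - 1, 0)
```

For five leaves, features 0 through 4 map to aggregators 0, 0, 1, 2, 3.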
- reset_optimizer(learning_rate=0.2)
Reset the optimizer for the model.
- Parameters:
learning_rate (float, optional) – Learning rate for the optimizer. Defaults to 0.2.
- save_model(file_name)
Save the model state to a file.
- Parameters:
file_name (str) – Path to save the model.
- train_model(x, y, epochs, is_frozen)
Train the binary tree logic network.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, input_size).
y (torch.Tensor) – Target tensor of shape (batch_size, 1).
epochs (int) – Number of training epochs.
- Returns:
History of loss values during training.
- Return type:
list
- unharden_alternating_tree() → None
Revert alternating tree hardening to allow continued training.
- unharden_full_tree() → None
Revert full tree hardening to allow continued training.