BACON Module

class bacon.FullWeightAggregator(*args, **kwargs)

Bases: AggregatorBase

aggregate_float(values: Sequence[float], a: float, weights: Sequence[float]) → float

Aggregate N scalar values with scalar andness a and N weights.

aggregate_tensor(values: Sequence[Any], andness, weights)

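Aggregates tensors using the Full Weight method. The parameters below describe the two-input case and do not match the signature name for name: x1 and x2 appear to correspond to the entries of values, and w0 and w1 to the entries of weights.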

Parameters:
  • x1 (torch.Tensor) – First input tensor.

  • x2 (torch.Tensor) – Second input tensor.

  • andness (float) – Andness.

  • w0 (float) – Weight for the first tensor.

  • w1 (float) – Weight for the second tensor.

Returns:

Resulting tensor after aggregation.

Return type:

torch.Tensor

class bacon.HalfWeightAggregator(*args, **kwargs)

Bases: AggregatorBase

aggregate_float(values: Sequence[float], a: float, weights: Sequence[float]) → float

Aggregate N scalar values with scalar andness a and N weights.

aggregate_tensor(values: Sequence[Any], andness, weights) → Any

Aggregate N tensors with andness a and N weights.

r(a)

class bacon.MinMaxAggregator(*args, **kwargs)

Bases: AggregatorBase

Aggregate a sequence of inputs with min/max interpolation.

The aggregator computes:

  • elementwise minimum when andness is near 1

  • elementwise maximum when andness is near 0

with a straight-through hard gating step so gradients can still flow through the continuous proxy during training.

aggregate_float(values: Sequence[float], a: float, weights: Sequence[float]) → float

Float-friendly wrapper around aggregate_tensor().

Parameters:
  • values – Input scalar values to aggregate.

  • a – Andness control value.

  • weights – Optional per-input weights used as soft gates.

Returns:

Python float result.

aggregate_tensor(values: Sequence[Any], andness, weights=None)

Aggregate one or more tensor-like values.

Parameters:
  • values – Non-empty sequence of tensors with compatible shapes.

  • andness – Control signal where larger values bias toward min/AND and smaller values bias toward max/OR.

  • weights – Optional per-input gates in [0, 1]. When provided, each input is blended with a neutral baseline before reduction.

Returns:

Tensor with the same broadcast-compatible shape as each input.

Raises:

ValueError – If values is empty.
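The min/max behaviour described for MinMaxAggregator can be illustrated with a scalar sketch. This is not the library's implementation: the function names, the linear blend used as the continuous proxy, and the 0.5 gate threshold are all assumptions made for illustration, and the optional weights are left out.

```python
def min_max_blend(values, andness):
    """Continuous proxy (assumed form): max at andness=0, min at andness=1."""
    if not values:
        raise ValueError("values must be non-empty")
    a = min(max(andness, 0.0), 1.0)  # clamp andness into [0, 1]
    return a * min(values) + (1.0 - a) * max(values)


def min_max_hard(values, andness, threshold=0.5):
    """Hard gate (assumed threshold): min when andness >= threshold, else max."""
    if not values:
        raise ValueError("values must be non-empty")
    return min(values) if andness >= threshold else max(values)


# A straight-through estimator would return the hard value in the forward
# pass while taking gradients from the blend; in an autograd framework this
# is commonly written as soft + (hard - soft).detach().
```

With tensors, the same idea applies elementwise across the inputs.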

baconNet


class bacon.baconNet.TrainingSetup(optimizer: Optimizer, criterion: Module, pos_weight: float | None, param_groups: list, aggregation_frozen: bool, actual_max_epochs: int, use_temperature_annealing: bool, perm_temp_decay_rate: float | None, trans_temp_decay_rate: float | None, operator_tau_decay_rate: float | None, operator_initial_tau: float | None, anneal_over_epochs: int, original_sparsity_weight: float | None, task_type: str, loss_history: list, accuracy_history: list, freeze_confidence_warning_shown: bool, best_loss_for_convergence: float, epochs_without_improvement: int, epoch_when_frozen: int | None, has_converged_before_freeze: bool, best_confidence: float, epochs_without_confidence_improvement: int, best_loss: float, epochs_since_improvement: int, temp_paused: bool, transformation_converged: bool, best_frozen_loss: float | None, best_frozen_state: dict | None, frozen_lr_reduced: bool, best_overall_loss: float | None, best_overall_state: dict | None, best_overall_epoch: int | None)

Bases: object

Encapsulates all training setup state for a single attempt.

accuracy_history: list
actual_max_epochs: int
aggregation_frozen: bool
anneal_over_epochs: int
best_confidence: float
best_frozen_loss: float | None
best_frozen_state: dict | None
best_loss: float
best_loss_for_convergence: float
best_overall_epoch: int | None
best_overall_loss: float | None
best_overall_state: dict | None
criterion: Module
epoch_when_frozen: int | None
epochs_since_improvement: int
epochs_without_confidence_improvement: int
epochs_without_improvement: int
freeze_confidence_warning_shown: bool
frozen_lr_reduced: bool
has_converged_before_freeze: bool
loss_history: list
operator_initial_tau: float | None
operator_tau_decay_rate: float | None
optimizer: Optimizer
original_sparsity_weight: float | None
param_groups: list
perm_temp_decay_rate: float | None
pos_weight: float | None
task_type: str
temp_paused: bool
trans_temp_decay_rate: float | None
transformation_converged: bool
use_temperature_annealing: bool
class bacon.baconNet.baconNet(input_size, tree_layout='left', loss_amplifier=1, weight_penalty_strength=0.001, weight_mode='trainable', weight_normalization='minmax', aggregator='lsp.full_weight', normalize_andness=True, is_frozen=False, use_transformation_layer=False, transformation_temperature=None, transformation_use_gumbel=False, transformations=None, early_stop_threshold_large_inputs=0.1, permutation_initial_temperature=5.0, permutation_final_temperature=0.1, transformation_initial_temperature=1.0, transformation_final_temperature=0.1, loss_weight_main=1.0, loss_weight_perm_entropy=0.0, loss_weight_trans_entropy=0.0, loss_weight_perm_sparsity=0.01, loss_weight_operator_sparsity=1.0, loss_weight_operator_l2=0.0, lr_permutation=0.3, lr_transformation=0.5, lr_aggregator=0.1, lr_other=0.1, use_class_weighting=True, loss_trim_percentile: float = 0.0, loss_trim_mode: str = 'drop_high', loss_trim_start_epoch: int = 0, training_policy=None, full_tree_depth: int = None, full_tree_shape: str = 'triangle', full_tree_temperature: float = 3.0, full_tree_final_temperature: float = 0.1, full_tree_max_egress: int = None, full_tree_concentrate_ingress: bool = False, full_tree_use_sinkhorn: bool = False, loss_weight_full_tree_egress: float = 0.0, loss_weight_full_tree_ingress: float = 0.5, loss_weight_full_tree_ingress_balance: float = 0.0, loss_weight_full_tree_scale_reg: float = 0.0, alternating_learn_first_routing: bool = True, alternating_learn_subsequent_routing: bool = True, alternating_learn_exponents: bool = False, alternating_min_exponent: float = 1.0, alternating_max_exponent: float = 2.0, alternating_max_egress: int = 1, alternating_use_straight_through: bool = True, loss_weight_alternating_balance: float = 50.0, loss_weight_alternating_egress: float = 0.5, loss_weight_alternating_exponent_reg: float = 0.0, use_constant_input: bool = False, use_permutation_layer: bool = True, regression_loss_type: str = 'mse')

Bases: Module

Represents a BACON network for interpretable decision-making using graded logic.

Parameters:
  • input_size (int) – Number of input features. This is likely to be removed in the future.

  • tree_layout (str, optional) – Layout of the tree. Defaults to “left”. Other layouts are not supported yet.

  • loss_amplifier (float, optional) – Amplifier for the loss. Defaults to 1.

  • weight_penalty_strength (float, optional) – Penalty strength on weights. Defaults to 1e-3. A stronger penalty leads to more balanced weights (closer to 0.5).

  • normalize_andness (bool, optional) – Whether to normalize andness. Defaults to True. Set this to False if the chosen aggregator, such as bool.min_max, already normalizes the andness.

  • weight_mode (str, optional) – Mode for weight configuration. Defaults to “trainable”. Use “fixed” for fixed weights (set to 0.5).

  • aggregator (str, optional) – Aggregator to be used. Defaults to “lsp.full_weight”.

  • is_frozen (bool, optional) – Whether to freeze the structure. Defaults to False.

  • early_stop_threshold_large_inputs (float, optional) – Early stop threshold for transformation layers with 10+ inputs. Defaults to 0.1. Lower values require more training but achieve higher accuracy.

  • transformations (list, optional) – List of transformation objects to use. If None, uses all 6 default transformations. Example: [IdentityTransformation(n), NegationTransformation(n)] for identity+negation only.

  • permutation_initial_temperature (float, optional) – Starting temperature for permutation annealing. Defaults to 5.0. Higher = more initial exploration.

  • permutation_final_temperature (float, optional) – Final temperature for permutation annealing. Defaults to 0.1. Lower = harder final permutation.

  • transformation_initial_temperature (float, optional) – Starting temperature for transformation layer. Defaults to 1.0. Should be lower than permutation since transformation is simpler (2^n vs n! states).

  • transformation_final_temperature (float, optional) – Final temperature for transformation layer. Defaults to 0.1. Same as permutation final temp.

  • loss_weight_main (float, optional) – Weight for main BCE loss. Defaults to 1.0.

  • loss_weight_perm_entropy (float, optional) – Weight for permutation entropy regularization. Defaults to 0.0. Higher = encourage exploration. Typical range: 0.0-0.1.

  • loss_weight_trans_entropy (float, optional) – Weight for transformation entropy regularization. Defaults to 0.0. Higher = encourage decisive transformation selection. Typical range: 0.0-0.1.

  • loss_weight_perm_sparsity (float, optional) – Weight for permutation sparsity loss. Defaults to 0.01. Penalizes high entropy (flat distributions) to encourage peaked/sparse permutations. Higher = stronger push toward confident or clear multi-modal distributions. Typical range: 0.0-0.1.

  • loss_weight_operator_sparsity (float, optional) – Weight for operator selection sparsity loss. Defaults to 1.0. Penalizes uncertain operator choices to encourage commitment. Higher = faster operator decision.

  • loss_weight_operator_l2 (float, optional) – Weight for L2 regularization on operator logits. Defaults to 0.0. Keeps logits bounded so tau can control commitment timing. Use > 0 when operators commit too early.

  • lr_permutation (float, optional) – Learning rate for permutation layer. Defaults to 0.3. Higher = faster exploration of feature orderings.

  • lr_transformation (float, optional) – Learning rate for transformation layer. Defaults to 0.5. Higher = faster transformation selection.

  • lr_aggregator (float, optional) – Learning rate for aggregator weights. Defaults to 0.1. Lower = more stable tree structure.

  • lr_other (float, optional) – Learning rate for other parameters. Defaults to 0.1.

  • use_class_weighting (bool, optional) – Whether to apply class weighting for imbalanced data. Defaults to True. When True, penalizes minority class errors more heavily (pos_weight = neg_count/pos_count). When False, uses standard BCE loss (original behavior).

  • full_tree_depth (int, optional) – Depth of the fully connected tree. Only used when tree_layout=“full”. Defaults to None (uses input_size - 1).

  • full_tree_shape (str, optional) – Shape of the fully connected tree. “triangle” (default) or “square”.

  • full_tree_temperature (float, optional) – Initial temperature for sigmoid edge weights. Defaults to 3.0.

  • full_tree_final_temperature (float, optional) – Final temperature after annealing. Defaults to 0.1.

  • full_tree_max_egress (int, optional) – Each source concentrates on top-K destinations (via loss). Defaults to None (no constraint).

  • loss_weight_full_tree_egress (float, optional) – Weight for full tree egress constraint loss. Defaults to 0.0.

  • loss_weight_full_tree_ingress (float, optional) – Weight for full tree ingress constraint loss (max 2 inputs per node). Defaults to 0.5.

  • use_permutation_layer (bool, optional) – Whether to use the permutation layer. Defaults to True. Set to False for full tree layout to let the tree learn input routing directly.

  • regression_loss_type (str, optional) – Loss type for regression mode. Defaults to “mse”. Supported values are “mse” (standard MSE, scale-sensitive), “correlation” (Pearson correlation loss, scale-invariant), and “normalized_mse” (z-score normalized MSE for scale-invariant pattern matching).
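The three regression_loss_type options differ mainly in how they treat scale. The sketch below shows one plausible reading of each on plain Python lists. The function names, the 1 - r form of the correlation loss, the population standard deviation, and the eps guard are assumptions, not the library's code.

```python
import math


def mse(pred, target):
    """Standard mean squared error (scale-sensitive)."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def _zscore(xs, eps=1e-8):
    """Center and scale to unit variance; eps guards against a zero std."""
    mean = sum(xs) / len(xs)
    std = math.sqrt(sum((x - mean) ** 2 for x in xs) / len(xs))
    return [(x - mean) / (std + eps) for x in xs]


def correlation_loss(pred, target):
    """Assumed form: 1 - Pearson r, so perfect positive correlation gives ~0."""
    zp, zt = _zscore(pred), _zscore(target)
    r = sum(a * b for a, b in zip(zp, zt)) / len(pred)
    return 1.0 - r


def normalized_mse(pred, target):
    """MSE computed after z-scoring both sides (scale-invariant)."""
    return mse(_zscore(pred), _zscore(target))
```

For predictions that are a scaled copy of the targets, such as pred = [10, 20, 30, 40] against target = [1, 2, 3, 4], plain MSE is large while both scale-invariant variants are close to zero.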

evaluate(x, y, threshold=0.5)
find_best_model(x, y, x_test, y_test, attempts=100, acceptance_threshold=0.95, save_path='./assembler.pth', max_epochs=12000, annealing_epochs=None, frozen_training_epochs=200, convergence_patience=500, convergence_delta=0.001, freeze_confidence_threshold=0.95, freeze_min_confidence=0.85, loss_weight_perm_sparsity=None, sparsity_schedule=None, freeze_aggregation_epochs=0, save_model=True, use_hierarchical_permutation=False, force_freeze=True, hierarchical_group_size=3, hierarchical_epochs_per_attempt=None, hierarchical_bleed_ratio=0.1, hierarchical_bleed_decay=2.0, sinkhorn_iters=100, binary_threshold=0.5, task_type='classification', operator_initial_tau=5.0, operator_final_tau=0.5, operator_freeze_min_confidence=0.7, operator_freeze_epochs=0, skip_frozen_threshold=0.99, full_tree_egress_warmup_epochs=0, full_tree_egress_ramp_epochs=0, full_tree_egress_start_metric=0.99, full_tree_egress_drop_tolerance=0.02, full_tree_egress_adapt_rate=0.2)

Find the best model by training multiple times and evaluating accuracy.

Parameters:
  • x (torch.Tensor) – Input tensor for training.

  • y (torch.Tensor) – Target tensor for training.

  • x_test (torch.Tensor) – Input tensor for testing.

  • y_test (torch.Tensor) – Target tensor for testing.

  • attempts (int, optional) – Number of attempts to find the best model. Defaults to 100.

  • acceptance_threshold (float, optional) – Minimum accuracy to accept a model. Defaults to 0.95.

  • save_path (str, optional) – Path to save the best model. Defaults to “./assembler.pth”.

  • max_epochs (int, optional) – Maximum epochs for training (safety limit). Defaults to 12000.

  • annealing_epochs (int, optional) – Epochs for temperature annealing. Defaults to None.

  • frozen_training_epochs (int, optional) – Epochs to train after freezing. Defaults to 200.

  • convergence_patience (int, optional) – Epochs without improvement before considering converged. Defaults to 500.

  • convergence_delta (float, optional) – Minimum loss improvement to reset patience. Defaults to 0.001.

  • freeze_confidence_threshold (float, optional) – Mean max probability threshold for high-confidence freezing. Defaults to 0.95.

  • freeze_min_confidence (float, optional) – Minimum confidence for early freeze when combined with low loss. Defaults to 0.85 (raised from 0.75 to be more conservative).

  • sinkhorn_iters (int, optional) – Number of Sinkhorn normalization iterations for soft permutation convergence. Higher values improve doubly-stochastic property but increase compute time. Defaults to 100.

  • loss_weight_perm_sparsity (float, optional) – Weight for permutation sparsity loss (encourages peaked distributions). If None, uses instance default. Defaults to None.

  • sparsity_schedule (tuple, optional) – Dynamic sparsity weight scheduling as (initial_weight, final_weight, transition_epochs). Example: (10.0, 0.1, 1000) starts with high sparsity emphasis (10.0) and linearly decreases to 0.1 over 1000 epochs. Defaults to None (uses constant loss_weight_perm_sparsity).

  • freeze_aggregation_epochs (int, optional) – Freeze aggregation parameters for first N epochs, allowing only permutation to learn. Useful for giving permutation undiluted classification signal. Defaults to 0 (no freezing).

  • save_model (bool, optional) – Whether to save the best model. Defaults to True.

  • use_hierarchical_permutation (bool, optional) – Use coarse-grained permutation exploration. Defaults to False.

  • hierarchical_group_size (int, optional) – Group size for hierarchical permutation (e.g., 3 for 10 inputs → 4x4 coarse matrix). Defaults to 3.

  • hierarchical_epochs_per_attempt (int, optional) – Epochs to run for each coarse permutation. If None, uses max_epochs. Defaults to None.

  • hierarchical_bleed_ratio (float, optional) – Ratio of std for adjacent blocks (0.0=hard blocks, 0.1=10% bleed, 1.0=full bleed). Defaults to 0.1.

  • hierarchical_bleed_decay (float, optional) – How quickly bleeding decays with distance (higher=faster decay). Defaults to 2.0.

  • operator_freeze_min_confidence (float, optional) – Minimum average operator selection confidence required before freezing permutation. Blocks freeze until operators commit to their choices. 0.0 disables the requirement (legacy behavior), 0.7 requires 70% average operator confidence. Defaults to 0.7.

  • operator_freeze_epochs (int, optional) – Freeze operator selection for first N epochs, allowing only edge/routing to learn. This decouples structure discovery from operator selection. After N epochs, operators unfreeze and start learning. Defaults to 0 (no operator freeze).

  • skip_frozen_threshold (float, optional) – Minimum metric required to skip frozen training after freeze. Only skips frozen training if freeze improves metric AND we’re above this threshold. For regression tasks, this should be very high (e.g., 0.99) to ensure weights fully converge. Defaults to 0.99.

  • full_tree_egress_warmup_epochs (int, optional) – In full-tree mode, keep egress concentration disabled for the first N epochs so structure can be learned first. Defaults to 0.

  • full_tree_egress_ramp_epochs (int, optional) – Number of epochs to ramp egress loss weight from 0 to target after warmup. Defaults to 0 (immediate application after warmup).

  • full_tree_egress_start_metric (float, optional) – Minimum training metric required before egress concentration begins. This lets the full tree learn the task first before sparsifying edges. Defaults to 0.99.

  • full_tree_egress_drop_tolerance (float, optional) – Maximum allowed training-metric drop from warmup baseline before reducing egress pressure. Defaults to 0.02.

  • full_tree_egress_adapt_rate (float, optional) – Fractional backoff/adjustment rate for dynamic egress weight when metric drop exceeds tolerance. Defaults to 0.2.

Returns:

Best model state dictionary and its metric (accuracy or R²).

Return type:

tuple
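Two of the schedules described above, sparsity_schedule and the full-tree egress warmup/ramp, are linear in the epoch count. The helpers below are a minimal sketch of that arithmetic. The function names and the clamping at the ends are assumptions, and the metric-gated and adaptive parts of the egress logic (full_tree_egress_start_metric, full_tree_egress_drop_tolerance, full_tree_egress_adapt_rate) are not modelled.

```python
def sparsity_weight(epoch, schedule):
    """Interpolate linearly from initial_weight to final_weight.

    schedule is (initial_weight, final_weight, transition_epochs).
    """
    initial, final, transition_epochs = schedule
    if transition_epochs <= 0:
        return final
    t = min(max(epoch / transition_epochs, 0.0), 1.0)  # clamp to [0, 1]
    return initial + (final - initial) * t


def egress_weight(epoch, target, warmup_epochs=0, ramp_epochs=0):
    """Zero during warmup, then a linear ramp up to target, then constant."""
    if epoch < warmup_epochs:
        return 0.0
    if ramp_epochs <= 0:
        return target  # no ramp: apply the full weight right after warmup
    t = min((epoch - warmup_epochs) / ramp_epochs, 1.0)
    return target * t
```

With schedule (10.0, 0.1, 1000), the weight starts at 10.0, passes roughly 5.05 at epoch 500, and stays at about 0.1 from epoch 1000 onward.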

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

inference(x, threshold=0.5)
inference_raw(x)
load_model(filepath)
make_param_groups()
prune_features(features)
save_model(filepath)
train_model(x, y, epochs)

binaryTreeLogicNet


class bacon.binaryTreeLogicNet.binaryTreeLogicNet(input_size, weight_mode='trainable', weight_normalization='minmax', weight_value=0.5, weight_range=(0.0, 1.0), weight_choices=None, noise_increase=1.05, noise_decrease=0.95, loss_amplifier=1.0, normalize_andness=True, min_noise=0.0, max_noise=2.0, is_frozen=False, tree_layout='left', weight_penalty_strength=0.001, aggregator='lsp.full_weight', early_stop_patience=10, early_stop_min_delta=0.0001, early_stop_threshold=0.01, use_transformation_layer: bool = False, transformation_temperature: float = 1.0, transformation_use_gumbel: bool = False, transformations=None, device=None, sinkhorn_iters=100, full_tree_depth: int = None, full_tree_shape: str = 'triangle', full_tree_temperature: float = 3.0, full_tree_final_temperature: float = 0.1, full_tree_max_egress: int = None, full_tree_concentrate_ingress: bool = False, full_tree_use_sinkhorn: bool = False, alternating_learn_first_routing: bool = True, alternating_learn_subsequent_routing: bool = True, alternating_learn_exponents: bool = False, alternating_min_exponent: float = 1.0, alternating_max_exponent: float = 2.0, alternating_max_egress: int = 1, alternating_use_straight_through: bool = True, alternating_balance_weight: float = 50.0, alternating_egress_weight: float = 0.5, use_constant_input: bool = False, use_permutation_layer: bool = True)

Bases: Module

Represents a binary tree logic network for interpretable decision-making using graded logic.

Parameters:
  • input_size (int) – Number of input features.

  • weight_mode (str, optional) – Mode for weight configuration. Defaults to “trainable”.

  • weight_value (float, optional) – Initial value for fixed weights. Defaults to 0.5.

  • weight_range (tuple, optional) – Range for random weights. Defaults to (0.0, 1.0).

  • weight_choices (list, optional) – Choices for discrete weights. Defaults to None.

  • noise_increase (float, optional) – Factor to increase noise. Defaults to 1.05.

  • noise_decrease (float, optional) – Factor to decrease noise. Defaults to 0.95.

  • loss_amplifier (float, optional) – Amplifier for the loss. Defaults to 1.0.

  • min_noise (float, optional) – Minimum noise level. Defaults to 0.0.

  • max_noise (float, optional) – Maximum noise level. Defaults to 2.0.

  • is_frozen (bool, optional) – Whether to freeze the structure. Defaults to False.

  • tree_layout (str, optional) – Layout of the tree. Defaults to “left”. Options: “left”, “balanced”, “paired”, “full”, “alternating”.

  • weight_penalty_strength (float, optional) – Penalty strength on weights. Defaults to 1e-3. A stronger penalty leads to more balanced weights (closer to 0.5).

  • aggregator (callable, optional) – Aggregator to be used. Defaults to “lsp.full_weight”.

  • device (torch.device, optional) – Device to run the model on. Defaults to None (uses CUDA if available).

  • full_tree_depth (int, optional) – Depth of the fully connected tree. Only used when tree_layout=“full”. Defaults to None (uses input_size - 1).

  • full_tree_shape (str, optional) – Shape of the fully connected tree. “triangle” (default) or “square”.

  • full_tree_temperature (float, optional) – Initial temperature for sigmoid edge weights. Defaults to 3.0.

  • full_tree_final_temperature (float, optional) – Final temperature after annealing. Defaults to 0.1.

  • full_tree_max_egress (int, optional) – Each source concentrates on top-K destinations (via loss). Defaults to None (no constraint).

  • use_permutation_layer (bool, optional) – Whether to use the permutation layer. Defaults to True. Set to False for full tree layout to let the tree learn input routing directly.

anneal_alternating_tree_gumbel(progress: float, initial: float = 1.0, final: float = 0.0) → None

Anneal the Gumbel noise scale of the alternating tree.

Parameters:
  • progress – Training progress from 0.0 to 1.0

  • initial – Initial noise scale

  • final – Final noise scale

anneal_alternating_tree_temperature(progress: float) → None

Anneal the temperature of the alternating tree.

Parameters:

progress – Training progress from 0.0 to 1.0

anneal_full_tree_gumbel(progress: float, initial: float = 1.0, final: float = 0.0) → None

Anneal the Gumbel noise scale of the fully connected tree.

Parameters:
  • progress – Training progress from 0.0 to 1.0

  • initial – Initial noise scale

  • final – Final noise scale

anneal_full_tree_temperature(progress: float) → None

Anneal the temperature of the fully connected tree.

Parameters:

progress – Training progress from 0.0 to 1.0
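The four annealing methods above all map a progress value in [0, 1] to a noise scale or temperature. The exact curves are not documented here, so the following is only a sketch of two common choices. The function names are hypothetical, and nothing confirms that the library uses either a linear or a geometric schedule.

```python
def anneal_linear(progress, initial=1.0, final=0.0):
    """Linear interpolation; works when final is 0 (e.g. a noise scale)."""
    p = min(max(progress, 0.0), 1.0)  # clamp progress into [0, 1]
    return initial + (final - initial) * p


def anneal_geometric(progress, initial, final):
    """Geometric interpolation; requires initial > 0 and final > 0."""
    p = min(max(progress, 0.0), 1.0)
    return initial * (final / initial) ** p
```

A geometric curve from 3.0 to 0.1 (the documented full_tree_temperature defaults) drops quickly early on and spends more of the run at low temperatures than the linear curve would.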

build_balanced_tree(node_outputs, weights, biases)

Build a balanced binary tree from node outputs.

Parameters:
  • node_outputs (list) – List of node outputs.

  • weights (list) – List of weights for the nodes.

  • biases (list) – List of biases for the nodes.

Returns:

Final output of the balanced tree.

Return type:

torch.Tensor

build_paired_tree(node_outputs, weights, biases)

Build a two-stage ‘paired’ tree: first pair inputs (0,1), (2,3), … then fold pairs.

Parameters:
  • node_outputs (list) – Leaf outputs.

  • weights (list) – Aggregator weights per node.

  • biases (list) – Aggregator biases per node.

Returns:

Final aggregated output.

Return type:

torch.Tensor
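The two-stage shape of the paired layout can be sketched with a generic binary combine function. The helper name, the pass-through handling of an odd leftover leaf, and the left fold in the second stage are assumptions. The per-node weights and biases that the real method takes are omitted.

```python
from functools import reduce


def paired_tree(leaves, combine):
    """Stage 1: combine (0,1), (2,3), ...; stage 2: left-fold the results."""
    if not leaves:
        raise ValueError("leaves must be non-empty")
    pairs = [combine(leaves[i], leaves[i + 1])
             for i in range(0, len(leaves) - 1, 2)]
    if len(leaves) % 2 == 1:
        pairs.append(leaves[-1])  # assumption: odd leftover joins stage 2 as-is
    return reduce(combine, pairs)


# Show the resulting shape by building a bracketed string.
def show(a, b):
    return f"({a} {b})"


shape = paired_tree(["x0", "x1", "x2", "x3", "x4"], show)
```

For the five leaves above, shape is "(((x0 x1) (x2 x3)) x4)".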

forward(x, targets=None)

Forward pass through the binary tree logic network.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch_size, input_size).

  • targets (torch.Tensor|None) – Optional target tensor [batch,1] in {0,1}; if provided and auto_refine is enabled during training, the model may run a light-weight permutation search and freeze the best permutation.

Returns:

Output tensor of shape (batch_size, 1).

Return type:

torch.Tensor

get_alternating_tree_balance_loss() → Tensor

Get balance loss for alternating tree routing.

Encourages balanced distribution of inputs across destinations. Returns 0 if not using “alternating” layout.

get_alternating_tree_egress_loss() → Tensor

Get egress loss for alternating tree routing.

Encourages peaked row distributions (clear winner per source). Returns 0 if not using “alternating” layout.

get_alternating_tree_exponent_regularization_loss() → Tensor

Get exponent regularization loss for alternating coefficient layers.

get_alternating_tree_num_nodes() → int

Get number of aggregation nodes in alternating tree.

Returns 0 if not using “alternating” layout.

get_alternating_tree_structure_description() → str

Get human-readable structure description for alternating tree.

Returns empty string if not using “alternating” layout.

get_full_tree_confidence() → float

Get confidence score for fully connected tree edge selections.

Higher values indicate more peaked edge weight distributions. Returns 0 if not using “full” layout.

get_full_tree_egress_loss() → Tensor

Get egress constraint loss for fully connected tree.

Encourages each source node to concentrate outgoing edges to top-K destinations, where K is controlled by full_tree_max_egress parameter. Returns 0 if not using “full” layout or max_egress is None.

get_full_tree_ingress_balance_loss() → Tensor

Get ingress balance loss for fully connected tree.

Encourages balanced distribution of inputs across destinations. Prevents all sources from routing to the same destination. Returns 0 if not using “full” layout.

get_full_tree_ingress_loss() → Tensor

Get ingress constraint loss for fully connected tree.

Discourages each destination node from receiving more than 2 inputs. Binary operators (add, mul) work best with exactly 2 inputs. Returns 0 if not using “full” layout.

get_full_tree_scale_regularization_loss() → Tensor

Get scale regularization loss for fully connected tree.

Penalizes extreme scale coefficients to prevent coefficient hacking. Returns 0 if not using “full” layout.

get_full_tree_structure() → dict

Get the learned structure of the fully connected tree.

Returns dictionary with layer widths, significant edges, and biases. Returns empty dict if not using “full” layout.

get_input_labels(variable_names=None)

harden_alternating_tree() → None

Harden the alternating tree to discrete edge selections.

harden_full_tree(mode: str = 'argmax') → None

Harden the fully connected tree to discrete edge selections.

Parameters:

mode – Hardening mode.

  • “argmax”: destination-wise argmax by default; if max_egress==1, uses row-wise argmax.

  • “auto”: row-wise argmax when max_egress==1, otherwise “smart”.

  • Other modes are forwarded to FullyConnectedTree.harden(…).

load_model(file_name)

Load the model state from a file.

Parameters:

file_name (str) – Path to load the model from.

Raises:

ValueError – If there’s an architecture mismatch that would cause random reinitialization.

prune_features(feature_index)

Prune a single feature by adjusting its corresponding aggregator weight.

In a left-associative tree:

  • Feature 0 is the left input of aggregator 0.

  • Feature 1 is the right input of aggregator 0.

  • Feature 2 is the right input of aggregator 1, so it uses the aggregator at index 1.

  • Feature k (k>=2) uses the aggregator at index k-1.

This is designed to be called incrementally from outside to build up cumulative pruning. Does NOT clear existing pruning state - adds to it.

Parameters:

feature_index (int) – Index of the feature to prune (0 to num_leaves-1).
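The feature-to-aggregator mapping listed above is small enough to write out directly. The helper below is a hypothetical illustration of that index arithmetic only; it is not part of the class and does not touch any pruning state.

```python
def aggregator_slot(feature_index):
    """Return (aggregator_index, side) for a feature in a left-associative tree."""
    if feature_index < 0:
        raise ValueError("feature_index must be non-negative")
    if feature_index == 0:
        return 0, "left"  # feature 0 is the left input of aggregator 0
    # Every later feature is the right input of aggregator feature_index - 1.
    return feature_index - 1, "right"
```

For example, feature 1 maps to (0, "right") and feature 4 maps to (3, "right").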

reset_optimizer(learning_rate=0.2)

Reset the optimizer for the model.

Parameters:

learning_rate (float, optional) – Learning rate for the optimizer. Defaults to 0.2.

save_model(file_name)

Save the model state to a file.

Parameters:

file_name (str) – Path to save the model.

train_model(x, y, epochs, is_frozen)

Train the binary tree logic network.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch_size, input_size).

  • y (torch.Tensor) – Target tensor of shape (batch_size, 1).

  • epochs (int) – Number of training epochs.

Returns:

History of loss values during training.

Return type:

list

unharden_alternating_tree() → None

Revert alternating tree hardening to allow continued training.

unharden_full_tree() → None

Revert full tree hardening to allow continued training.