canns.models.brain_inspired.linear

Generic linear layer for brain-inspired learning algorithms.

Classes

LinearLayer

Generic linear feedforward layer for brain-inspired learning rules.

Module Contents

class canns.models.brain_inspired.linear.LinearLayer(input_size, output_size, use_bcm_threshold=False, threshold_tau=100.0, **kwargs)[source]

Bases: canns.models.brain_inspired._base.BrainInspiredModel

Generic linear feedforward layer for brain-inspired learning rules.

It computes a simple linear transform:

y = W @ x

You can pair it with trainers like OjaTrainer, BCMTrainer, or HebbianTrainer.
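The trainer classes are documented elsewhere; as a rough, hand-rolled sketch of what an Oja-style update to W looks like (illustrative only, not the OjaTrainer API):

```python
import jax.numpy as jnp

def oja_step(W, x, lr=0.01):
    # Oja's rule (1982): dW_ij = lr * (y_i * x_j - y_i**2 * W_ij),
    # which pushes each row of W toward unit norm over repeated updates.
    y = W @ x  # forward pass: y = W @ x
    return W + lr * (jnp.outer(y, x) - (y ** 2)[:, None] * W)

W = jnp.array([[0.5, 0.0, 0.5],
               [0.0, 0.5, 0.5]])
x = jnp.array([1.0, 0.5, -1.0])
W = oja_step(W, x)
```

A trainer would apply such a step to the layer's W after each forward pass; the sketch above just shows the arithmetic on a standalone matrix.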

Examples

>>> import jax.numpy as jnp
>>> from canns.models.brain_inspired import LinearLayer
>>>
>>> layer = LinearLayer(input_size=3, output_size=2)
>>> y = layer.forward(jnp.array([1.0, 0.5, -1.0], dtype=jnp.float32))
>>> y.shape
(2,)

References

  • Oja (1982): Simplified neuron model as a principal component analyzer

  • Bienenstock et al. (1982): Theory for the development of neuron selectivity

Initialize the linear layer.

Parameters:
  • input_size (int) – Dimensionality of input vectors

  • output_size (int) – Number of output neurons (features to extract)

  • use_bcm_threshold (bool) – Whether to maintain sliding threshold for BCM learning

  • threshold_tau (float) – Time constant for threshold sliding average (only used if use_bcm_threshold=True)

  • **kwargs – Additional arguments passed to parent class

forward(x)[source]

Compute the layer output for one input vector.

Parameters:

x (jax.numpy.ndarray) – Input vector of shape (input_size,).

Returns:

Output vector of shape (output_size,).

Return type:

jax.numpy.ndarray

resize(input_size, output_size=None, preserve_submatrix=True)[source]

Resize layer dimensions.

Parameters:
  • input_size (int) – New input dimension

  • output_size (int | None) – New output dimension (if None, keep current)

  • preserve_submatrix (bool) – Whether to preserve existing weights
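One plausible shape for the submatrix-preserving resize, sketched in plain jax.numpy under the assumption that newly added entries are zero-initialized (the layer's actual initializer may differ):

```python
import jax.numpy as jnp

def resize_weights(W, new_input_size, new_output_size, preserve_submatrix=True):
    # Illustrative sketch: allocate a new weight matrix and, if requested,
    # copy the overlapping top-left block of the old weights into it.
    # Assumes zero initialization for new entries.
    new_W = jnp.zeros((new_output_size, new_input_size))
    if preserve_submatrix:
        r = min(W.shape[0], new_output_size)
        c = min(W.shape[1], new_input_size)
        new_W = new_W.at[:r, :c].set(W[:r, :c])
    return new_W

W = jnp.ones((2, 3))
W_big = resize_weights(W, new_input_size=4, new_output_size=2)
# W_big has shape (2, 4); the first three columns carry the old weights.
```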

update(prev_energy)[source]

Update method for trainer compatibility (no-op for feedforward layer).

update_threshold()[source]

Update the sliding threshold based on recent activity (BCM only).

This method should be called by BCMTrainer after each forward pass. Updates θ using: θ ← θ + (1/τ) * (y² - θ)
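The same rule can be sketched on a standalone threshold vector (illustrative; the layer applies this update to its stored threshold):

```python
import jax.numpy as jnp

def bcm_threshold_step(theta, y, tau=100.0):
    # Sliding-average update: theta <- theta + (1/tau) * (y**2 - theta).
    # theta tracks a running average of the squared postsynaptic activity,
    # which BCM uses as the modification threshold.
    return theta + (y ** 2 - theta) / tau

theta = jnp.zeros(2)
y = jnp.array([1.0, 2.0])
theta = bcm_threshold_step(theta, y, tau=100.0)  # -> [0.01, 0.04]
```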

W[source]

Weight matrix of shape (output_size, input_size).
property energy: float[source]

Energy for trainer compatibility (0 for feedforward layer).

input_size[source]
output_size[source]
property predict_state_attr: str[source]

Name of output state for prediction.

threshold_tau = 100.0[source]
use_bcm_threshold = False[source]
property weight_attr: str[source]

Name of weight parameter for generic training.

x[source]
y[source]