canns.models.brain_inspired.linear¶
Generic linear layer for brain-inspired learning algorithms.
Classes¶
LinearLayer – Generic linear feedforward layer for brain-inspired learning rules.
Module Contents¶
- class canns.models.brain_inspired.linear.LinearLayer(input_size, output_size, use_bcm_threshold=False, threshold_tau=100.0, **kwargs)[source]¶
Bases: canns.models.brain_inspired._base.BrainInspiredModel
Generic linear feedforward layer for brain-inspired learning rules.
- It computes a simple linear transform:
y = W @ x
You can pair it with trainers such as OjaTrainer, BCMTrainer, or HebbianTrainer.
Examples
>>> import jax.numpy as jnp
>>> from canns.models.brain_inspired import LinearLayer
>>>
>>> layer = LinearLayer(input_size=3, output_size=2)
>>> y = layer.forward(jnp.array([1.0, 0.5, -1.0], dtype=jnp.float32))
>>> y.shape
(2,)
References
Oja (1982): Simplified neuron model as a principal component analyzer
Bienenstock et al. (1982): Theory for the development of neuron selectivity
Initialize the linear layer.
- Parameters:
input_size (int) – Dimensionality of input vectors
output_size (int) – Number of output neurons (features to extract)
use_bcm_threshold (bool) – Whether to maintain sliding threshold for BCM learning
threshold_tau (float) – Time constant for threshold sliding average (only used if use_bcm_threshold=True)
**kwargs – Additional arguments passed to parent class
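The page does not show how the BCM sliding threshold is maintained internally, but the standard formulation (Bienenstock et al., 1982) tracks an exponential moving average of squared postsynaptic activity with time constant `threshold_tau`. A minimal sketch of that update, assuming this conventional form:

```python
import jax.numpy as jnp


def update_bcm_threshold(theta, y, tau=100.0):
    """Slide the BCM modification threshold toward the recent
    average of squared postsynaptic activity.

    theta: current threshold per output neuron, shape (output_size,)
    y:     current output activity, shape (output_size,)
    tau:   time constant of the sliding average (cf. threshold_tau)
    """
    # Discrete-time exponential moving average:
    # d(theta)/dt = (y**2 - theta) / tau
    return theta + (y**2 - theta) / tau


theta = jnp.zeros(2)
y = jnp.array([1.0, -2.0])
theta = update_bcm_threshold(theta, y, tau=100.0)
# theta takes a small step (1/tau) toward y**2 = [1.0, 4.0]
```

A larger `threshold_tau` makes the threshold slide more slowly, so the BCM rule averages activity over a longer history. `update_bcm_threshold` is a hypothetical helper for illustration, not part of the canns API.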
- forward(x)[source]¶
Compute the layer output for one input vector.
- Parameters:
x (jax.numpy.ndarray) – Input vector of shape (input_size,).
- Returns:
Output vector of shape (output_size,).
- Return type:
jax.numpy.ndarray
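To illustrate how a trainer like OjaTrainer would use this forward pass, here is a by-hand sketch of one Oja (1982) weight update on a weight matrix of the same shape as the layer's. The trainer's actual internals are not shown on this page; this only demonstrates the referenced rule:

```python
import jax.numpy as jnp


def oja_update(W, x, lr=0.01):
    """One Oja's-rule update for a linear layer y = W @ x.

    The decay term -y_i**2 * W_ij keeps the weight norm bounded,
    unlike plain Hebbian learning (dW_ij = lr * y_i * x_j).
    """
    y = W @ x  # forward pass, shape (output_size,)
    # dW_ij = lr * y_i * (x_j - y_i * W_ij)
    return W + lr * (jnp.outer(y, x) - (y**2)[:, None] * W)


W0 = jnp.array([[0.1, 0.2, 0.0],
                [0.0, -0.1, 0.3]])  # (output_size=2, input_size=3)
x = jnp.array([1.0, 0.5, -1.0])
W_new = oja_update(W0, x, lr=0.01)
```

Repeated over many inputs, this drives the rows of `W` toward the leading principal components of the input distribution, which is why the class docstring describes the output neurons as "features to extract".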