src.canns.models.basic
Submodules
Classes
- CANN1D – A standard 1D Continuous Attractor Neural Network (CANN) model.
- CANN1D_SFA – A 1D CANN model that incorporates Spike-Frequency Adaptation (SFA).
- CANN2D – A 2D Continuous Attractor Neural Network (CANN) model.
- CANN2D_SFA – A 2D Continuous Attractor Neural Network (CANN) model with a specific implementation of Spike-Frequency Adaptation (SFA) dynamics.
- GridCell2DPosition – Position-based 2D continuous-attractor grid cell network with hexagonal lattice structure.
- GridCell2DVelocity – Velocity-based grid cell network (Burak & Fiete 2009).
- HierarchicalNetwork – A full hierarchical network composed of multiple grid modules.
Package Contents¶
- class src.canns.models.basic.CANN1D(*args, **kwargs)
Bases: BaseCANN1D
A standard 1D Continuous Attractor Neural Network (CANN) model. This model implements the core dynamics where a localized “bump” of activity can be sustained and moved by external inputs.
- Reference:
Wu, S., Hamaguchi, K., & Amari, S. I. (2008). Dynamics and computation of continuous attractors. Neural computation, 20(4), 994-1025.
Initializes the 1D CANN model.
- Parameters:
(Parameters are inherited from BaseCANN1D.)
- update(inp)
The main update function, defining the dynamics of the network for one time step.
- Parameters:
inp (Array) – The external input for the current time step.
- inp
- r
- u
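The update above can be sketched in plain NumPy. This is a minimal sketch, assuming the Wu et al. (2008) formulation tau·du/dt = -u + Σ_j J(x, x_j)·r_j + I_ext with r = u² / (1 + k·Σ_j u_j²), where the neuron density is absorbed into plain sums over neurons; CANN1D itself builds on brainpy.math and may scale terms differently. The helper names (ring_distance, make_cann1d, cann1d_step) and the values a=0.5, J0=4.0 and input amplitude 10.0 are illustrative, not CANN1D defaults.

```python
import numpy as np

def ring_distance(d, width=2 * np.pi):
    """Wrap displacements onto a ring of circumference `width`."""
    return (d + width / 2) % width - width / 2

def make_cann1d(num=256, a=0.5, J0=4.0):
    """Preferred stimuli x and Gaussian recurrent weights J (illustrative a, J0)."""
    x = np.linspace(-np.pi, np.pi, num, endpoint=False)
    d = ring_distance(x[:, None] - x[None, :])
    J = J0 * np.exp(-0.5 * (d / a) ** 2) / (np.sqrt(2 * np.pi) * a)
    return x, J

def cann1d_step(u, inp, J, k=8.1, tau=1.0, dt=0.1):
    """One forward-Euler step of tau * du/dt = -u + J @ r + inp."""
    u_sq = u ** 2
    r = u_sq / (1.0 + k * u_sq.sum())          # divisive global inhibition
    u = u + dt / tau * (-u + J @ r + inp)
    return u, r

x, J = make_cann1d()
inp = 10.0 * np.exp(-0.25 * (ring_distance(x - 0.5) / 0.5) ** 2)  # stimulus centred at z = 0.5
u = np.zeros_like(x)
for _ in range(200):
    u, r = cann1d_step(u, inp, J)
print(x[np.argmax(u)])  # within one grid step of 0.5
```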
- class src.canns.models.basic.CANN1D_SFA(num, tau=1.0, tau_v=50.0, k=8.1, a=0.3, A=0.2, J0=1.0, z_min=-bm.pi, z_max=bm.pi, m=0.3, **kwargs)
Bases: BaseCANN1D
A 1D CANN model that incorporates Spike-Frequency Adaptation (SFA). SFA is a slow negative feedback mechanism that causes neurons to fire less over time for a sustained input, which can induce anticipative tracking behavior.
- Reference:
Mi, Y., Fung, C. C., Wong, K. Y., & Wu, S. (2014). Spike frequency adaptation implements anticipative tracking in continuous attractor neural networks. Advances in neural information processing systems, 27.
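Initializes the 1D CANN model with SFA.
- Parameters:
num (int) – Number of neurons.
tau (float) – Synaptic time constant. Default: 1.0
tau_v (float) – Time constant of the adaptation variable. Default: 50.0
k (float) – Global inhibition strength. Default: 8.1
a (float) – Width of the connection kernel. Default: 0.3
A (float) – Magnitude of the external input. Default: 0.2
J0 (float) – Maximum recurrent connection strength. Default: 1.0
z_min (float) – Lower bound of the feature space. Default: -bm.pi
z_max (float) – Upper bound of the feature space. Default: bm.pi
m (float) – Strength of the spike-frequency adaptation. Default: 0.3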
- update(inp)
The main update function for the SFA model. It includes dynamics for both the membrane potential and the adaptation variable.
- Parameters:
inp (Array) – The external input for the current time step.
- inp
- m = 0.3
- r
- tau_v = 50.0
- u
- v
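A sketch of how SFA enters the update, assuming the Mi et al. (2014) form tau·du/dt = -u + J·r + I - v and tau_v·dv/dt = -v + m·u. It runs the same ring twice under a sustained input, once with m=0.3 and once with m=0 (no adaptation), to show the adapted bump settling at a lower peak. Anticipative tracking needs a moving input and is not shown; the helper names make_ring and sfa_step are hypothetical.

```python
import numpy as np

def make_ring(num=128, a=0.3, J0=1.0):
    x = np.linspace(-np.pi, np.pi, num, endpoint=False)
    d = (x[:, None] - x[None, :] + np.pi) % (2 * np.pi) - np.pi  # wrapped distances
    J = J0 * np.exp(-0.5 * (d / a) ** 2) / (np.sqrt(2 * np.pi) * a)
    return x, J

def sfa_step(u, v, inp, J, k=8.1, tau=1.0, tau_v=50.0, m=0.3, dt=0.1):
    """One Euler step of the membrane potential u and the adaptation variable v."""
    u_sq = u ** 2
    r = u_sq / (1.0 + k * u_sq.sum())   # divisive global inhibition
    du = (-u + J @ r + inp - v) / tau   # adaptation is subtracted from the drive
    dv = (-v + m * u) / tau_v           # v slowly tracks m * u
    return u + dt * du, v + dt * dv

x, J = make_ring()
inp = 0.2 * np.exp(-0.25 * (x / 0.3) ** 2)  # sustained input centred at 0, A = 0.2
u_sfa, v_sfa = np.zeros_like(x), np.zeros_like(x)
u_plain, v_plain = np.zeros_like(x), np.zeros_like(x)
for _ in range(2000):                        # 200 time units, i.e. 4 * tau_v
    u_sfa, v_sfa = sfa_step(u_sfa, v_sfa, inp, J)
    u_plain, v_plain = sfa_step(u_plain, v_plain, inp, J, m=0.0)  # no adaptation
print(u_plain.max(), u_sfa.max())  # the adapted peak is lower
```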
- class src.canns.models.basic.CANN2D(*args, **kwargs)
Bases: BaseCANN2D
A 2D Continuous Attractor Neural Network (CANN) model. This model extends BaseCANN2D with the specific dynamics and properties of a 2D neural network.
- Reference:
Wu, S., Hamaguchi, K., & Amari, S. I. (2008). Dynamics and computation of continuous attractors. Neural computation, 20(4), 994-1025.
Initializes the 2D CANN model.
- Parameters:
(Parameters are inherited from BaseCANN2D.)
- update(inp)
The main update function, defining the dynamics of the network for one time step.
- Parameters:
inp (Array) – The external input to the network, which can be a stimulus or other driving force.
- inp
- r
- u
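In 2D, the main change from the 1D case is the distance used by the recurrent kernel: displacements are wrapped on each axis independently, so the neural sheet is a torus. The sketch below assumes a Gaussian kernel normalized by 2π·a² (BaseCANN2D may normalize differently); the dynamics then keep the 1D form, applied to the flattened length² state vector. Because of the periodic wrap, every neuron receives the same total recurrent weight.

```python
import numpy as np

def make_torus_kernel(length=16, a=0.5, J0=1.0):
    """Gaussian recurrent kernel on a length x length periodic sheet (a torus)."""
    x = np.linspace(-np.pi, np.pi, length, endpoint=False)
    xx, yy = np.meshgrid(x, x, indexing="ij")
    pos = np.stack([xx.ravel(), yy.ravel()], axis=1)   # (length**2, 2) preferred positions
    d = pos[:, None, :] - pos[None, :, :]               # pairwise displacements
    d = (d + np.pi) % (2 * np.pi) - np.pi               # wrap each axis independently
    dist_sq = (d ** 2).sum(axis=-1)
    return pos, J0 * np.exp(-0.5 * dist_sq / a**2) / (2 * np.pi * a**2)

pos, J = make_torus_kernel()
row_sums = J.sum(axis=1)
print(J.shape, row_sums.max() - row_sums.min())  # (256, 256) and a spread near zero
```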
- class src.canns.models.basic.CANN2D_SFA(length, tau=1.0, tau_v=50.0, k=8.1, a=0.3, A=0.2, J0=1.0, z_min=-bm.pi, z_max=bm.pi, m=0.3, **kwargs)
Bases: BaseCANN2D
A 2D Continuous Attractor Neural Network (CANN) model with a specific implementation of Spike-Frequency Adaptation (SFA) dynamics. This model extends BaseCANN2D with SFA-specific dynamics.
Initializes the 2D CANN model with SFA dynamics.
- update(inp)
The main update function for the SFA model. It includes dynamics for both the membrane potential and the adaptation variable.
- Parameters:
inp (Array) – The external input for the current time step.
- inp
- m = 0.3
- r
- tau_v = 50.0
- u
- v
- class src.canns.models.basic.GridCell2DPosition(length=30, tau=10.0, k=1.0, a=0.8, A=3.0, J0=5.0, mapping_ratio=1.5, noise_strength=0.1, conn_noise=0.0, g=1.0)
Bases: src.canns.models.basic._base.BasicModel
Position-based 2D continuous-attractor grid cell network with hexagonal lattice structure.
This network implements a twisted torus topology that generates grid cell-like spatial representations with hexagonal periodicity.
The network operates in a transformed coordinate system where grid cells form a hexagonal lattice, enabling realistic grid field spacing and orientation.
- Parameters:
length (int) – Number of grid cells along one dimension (total = length^2). Default: 30
tau (float) – Membrane time constant (ms). Default: 10.0
k (float) – Global inhibition strength for divisive normalization. Default: 1.0
a (float) – Width of connectivity kernel. Determines bump width. Default: 0.8
A (float) – Amplitude of external input. Default: 3.0
J0 (float) – Peak recurrent connection strength. Default: 5.0
mapping_ratio (float) – Controls grid spacing (larger = smaller spacing). Grid spacing λ = 2π / mapping_ratio. Default: 1.5
noise_strength (float) – Standard deviation of activity noise. Default: 0.1
conn_noise (float) – Standard deviation of connectivity noise. Default: 0.0
g (float) – Firing rate gain factor (scales to biological range). Default: 1.0
- x_grid, y_grid
Grid cell preferred phases in [-π, π]
- Type:
Array
- value_grid
Neuron positions in phase space, shape (num, 2)
- Type:
Array
- coor_transform
Hexagonal to rectangular coordinate transform
- Type:
Array
- coor_transform_inv
Rectangular to hexagonal coordinate transform
- Type:
Array
- conn_mat
Recurrent connectivity matrix
- Type:
Array
- candidate_centers
Grid of candidate bump centers for decoding
- Type:
Array
- r
Firing rates (shape: num)
- Type:
Variable
- u
Membrane potentials (shape: num)
- Type:
Variable
- center_phase
Decoded bump center in phase space (shape: 2)
- Type:
Variable
- center_position
Decoded position in real space (shape: 2)
- Type:
Variable
- inp
External input for tracking (shape: num)
- Type:
Variable
- gc_bump
Grid cell bump activity pattern (shape: num)
- Type:
Variable
Example
>>> import brainpy.math as bm
>>> from canns.models.basic import GridCell2DPosition
>>>
>>> bm.set_dt(1.0)
>>> model = GridCell2DPosition(length=30, mapping_ratio=1.5)
>>>
>>> # Update with 2D position
>>> position = [0.5, 0.3]
>>> model.update(position)
>>>
>>> # Access decoded position
>>> decoded_pos = model.center_position.value
>>> print(f"Decoded position: {decoded_pos}")
References
Burak, Y., & Fiete, I. R. (2009). Accurate path integration in continuous attractor network models of grid cells. PLoS Computational Biology, 5(2), e1000291.
Initialize the simplified grid cell network.
- calculate_dist(d)
Calculate Euclidean distance after hexagonal coordinate transformation.
Applies periodic boundary conditions and transforms displacement vectors from phase space to hexagonal lattice coordinates.
- Parameters:
d – Displacement vectors in phase space, shape (…, 2)
- Returns:
Euclidean distances in hexagonal space
- Return type:
Array of shape (…,)
- get_stimulus_by_pos(position)
Generate Gaussian stimulus centered at given position.
Useful for previewing input patterns without calling update.
- Parameters:
position – 2D position [x, y] in real space
- Returns:
External input pattern
- Return type:
Array of shape (num,)
- get_unique_activity_bump(network_activity, animal_position)
Decode unique bump location from network activity and animal position.
Estimates the activity bump center in phase space using population vector decoding, then maps it to real space and snaps to the nearest candidate center to resolve periodic ambiguity.
- Parameters:
network_activity – Grid cell firing rates, shape (num,)
animal_position – Current animal position for disambiguation, shape (2,)
- Returns:
A tuple (center_phase, center_position, bump):
center_phase – Phase coordinates of the bump center, shape (2,)
center_position – Real-space position of the bump (nearest candidate center), shape (2,)
bump – Gaussian bump template centered at center_position, shape (num,)
- Return type:
tuple
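The population vector step can be sketched as a circular mean per phase axis: each neuron votes with the unit vector exp(i·phase), weighted by its rate, and the angle of the sum is the bump center. A minimal sketch on a synthetic Gaussian bump; the mapping to real space and the snapping to candidate_centers are not shown, and the helper names wrap, decode_bump_phase and synthetic_bump are hypothetical.

```python
import numpy as np

def wrap(d):
    """Wrap phase differences into [-pi, pi)."""
    return (d + np.pi) % (2 * np.pi) - np.pi

def decode_bump_phase(rates, phases):
    """Rate-weighted circular mean of the neurons' phases, one angle per axis."""
    z = (rates[:, None] * np.exp(1j * phases)).sum(axis=0)
    return np.angle(z)

length = 30
x = np.linspace(-np.pi, np.pi, length, endpoint=False)
xx, yy = np.meshgrid(x, x, indexing="ij")
phases = np.stack([xx.ravel(), yy.ravel()], axis=1)  # (num, 2) preferred phases

def synthetic_bump(center, width=0.5):
    """Gaussian bump on the phase torus (synthetic test data)."""
    d = wrap(phases - np.asarray(center))
    return np.exp(-(d ** 2).sum(axis=1) / (2 * width ** 2))

print(decode_bump_phase(synthetic_bump([1.0, -2.0]), phases))  # close to [1.0, -2.0]
print(decode_bump_phase(synthetic_bump([3.0, -3.0]), phases))  # also works across the +/-pi seam
```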
- handle_periodic_condition(d)
Apply periodic boundary conditions to wrap phases into [-π, π].
- Parameters:
d – Phase values (any shape with last dimension = 2)
- Returns:
Wrapped phase values in [-π, π]
- make_candidate_centers(Lambda)
Generate grid of candidate bump centers for decoding.
Creates a regular lattice of potential activity bump locations used for disambiguating position from grid cell phases.
- Parameters:
Lambda – Grid spacing in real space
- Returns:
Candidate centers in transformed coordinates
- Return type:
Array of shape (N_c*N_c, 2)
- make_connection()
Generate recurrent connectivity matrix with 2D Gaussian kernel.
Uses hexagonal lattice geometry via coordinate transformation. Connection strength decays with distance in transformed space.
- Returns:
Recurrent connectivity matrix
- Return type:
Array of shape (num, num)
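A sketch of calculate_dist and make_connection as free-function stand-ins, assuming the hexagonal basis vectors (1, 0) and (1/2, √3/2); the matrices actually stored in coor_transform and coor_transform_inv may differ, as may the kernel normalization. With this basis, the six nearest lattice neighbours all sit at distance 1, which is what gives the connectivity its hexagonal symmetry.

```python
import numpy as np

# Hypothetical hexagonal basis: columns are the lattice vectors (1, 0) and (1/2, sqrt(3)/2)
HEX_BASIS = np.array([[1.0, 0.5], [0.0, np.sqrt(3.0) / 2.0]])

def handle_periodic_condition(d):
    """Wrap phase differences into [-pi, pi) along each axis."""
    return (d + np.pi) % (2 * np.pi) - np.pi

def calculate_dist(d):
    """Wrap phase displacements, map them onto the hexagonal lattice, take the norm."""
    d_hex = handle_periodic_condition(d) @ HEX_BASIS.T  # row-vector form of HEX_BASIS @ d
    return np.linalg.norm(d_hex, axis=-1)

def make_connection(length=30, a=0.8, J0=5.0):
    x = np.linspace(-np.pi, np.pi, length, endpoint=False)
    xx, yy = np.meshgrid(x, x, indexing="ij")
    value_grid = np.stack([xx.ravel(), yy.ravel()], axis=1)  # (num, 2) phases
    d = value_grid[:, None, :] - value_grid[None, :, :]       # (num, num, 2)
    dist = calculate_dist(d)
    return J0 * np.exp(-0.5 * (dist / a) ** 2) / (2 * np.pi * a**2)

conn_mat = make_connection(length=10)
print(calculate_dist(np.array([1.0, -1.0])), calculate_dist(np.array([1.0, 1.0])))  # 1 and sqrt(3)
```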
- position2phase(position)
Convert real-space position to grid cell phase coordinates.
Applies coordinate transformation and wraps to periodic boundaries. Each grid cell’s preferred phase is determined by the animal’s position on the hexagonal lattice.
- Parameters:
position – Real-space coordinates, shape (2,) or (2, N)
- Returns:
Phase coordinates in [-π, π] per axis
- Return type:
Array of shape (2,) or (2, N)
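A sketch of position2phase, assuming phases are obtained by scaling the position by mapping_ratio, expressing it in a hypothetical hexagonal basis, and wrapping into [-π, π). Under these assumptions, shifting the position by the grid spacing λ = 2π / mapping_ratio along either lattice vector leaves the phase unchanged, while other shifts do not; that is the hexagonal periodicity of the grid fields.

```python
import numpy as np

# Hypothetical hexagonal basis: columns are the lattice vectors (1, 0) and (1/2, sqrt(3)/2)
HEX_BASIS = np.array([[1.0, 0.5], [0.0, np.sqrt(3.0) / 2.0]])
TO_LATTICE = np.linalg.inv(HEX_BASIS)  # real space -> lattice coordinates

def position2phase(position, mapping_ratio=1.5):
    """Map positions of shape (2,) or (2, N) to phases wrapped into [-pi, pi)."""
    lattice = TO_LATTICE @ (np.asarray(position, dtype=float) * mapping_ratio)
    return (lattice + np.pi) % (2 * np.pi) - np.pi

spacing = 2 * np.pi / 1.5  # grid spacing lambda for mapping_ratio = 1.5
p = np.array([0.3, 0.2])
base = position2phase(p)
on_lattice = position2phase(p + spacing * HEX_BASIS[:, 1])  # shift along (1/2, sqrt(3)/2)
off_lattice = position2phase(p + spacing * np.array([0.0, 1.0]))
print(np.allclose(on_lattice, base), np.allclose(off_lattice, base))  # True False
```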
- update(position)
Single time-step update of grid cell network dynamics.
Updates network activity using continuous attractor dynamics with direct position-based external input. No adaptation or theta modulation.
- Parameters:
position – Current 2D position [x, y] in real space, shape (2,)
- A = 3.0
- J0 = 5.0
- Lambda
- a = 0.8
- candidate_centers
- center_phase
- center_position
- conn_mat
- conn_noise = 0.0
- coor_transform
- coor_transform_inv
- g = 1.0
- gc_bump
- inp
- k = 1.0
- length = 30
- mapping_ratio = 1.5
- noise_strength = 0.1
- num = 900
- r
- tau = 10.0
- u
- value_bump
- value_grid
- x_grid
- y_grid
- class src.canns.models.basic.GridCell2DVelocity(length=40, tau=0.01, alpha=0.2, A=1.0, W_a=1.5, W_l=2.0, lambda_net=15.0, e=1.15, use_sparse=False)
Bases: src.canns.models.basic._base.BasicModel
Velocity-based grid cell network (Burak & Fiete 2009).
This network implements path integration through velocity-modulated input and asymmetric connectivity. Unlike position-based models, this takes velocity as input and integrates it over time to track position.
- Key Features:
Velocity-dependent input modulation: B(v) = A * (1 + α·v·v_pref)
Asymmetric connectivity shifted in preferred velocity directions
Simple ReLU activation (not divisive normalization)
Healing process for proper initialization
- Parameters:
length (int) – Number of neurons along one dimension (total = length²). Default: 40
tau (float) – Membrane time constant. Default: 0.01
alpha (float) – Velocity coupling strength. Default: 0.2
A (float) – Baseline input amplitude. Default: 1.0
W_a (float) – Connection amplitude (>1 makes close surround activatory). Default: 1.5
W_l (float) – Spatial shift size for asymmetric connectivity. Default: 2.0
lambda_net (float) – Lattice constant (neurons between bump centers). Default: 15.0
e (float) – Controls the spread of the inhibitory surround. Default: 1.15. W_gamma and W_beta are computed from e and lambda_net.
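The listed defaults W_beta = 0.01333... and W_gamma = 0.01533... are consistent with β = 3 / lambda_net² and γ = e·β from Burak & Fiete (2009), who fix that factor at 1.05 where this class exposes e. The sketch below assumes those relations and the kernel form W(d) = W_a·exp(-W_gamma·d²) - exp(-W_beta·d²), under which W_a > 1 is what makes the center of the kernel excitatory; with W_a = 1 the kernel is nowhere positive.

```python
import math

lambda_net, e, W_a = 15.0, 1.15, 1.5
W_beta = 3.0 / lambda_net**2  # assumed: beta = 3 / lambda_net^2
W_gamma = e * W_beta          # assumed: gamma = e * beta

def kernel(d, W_a=W_a):
    """Radial profile W(d) = W_a * exp(-W_gamma d^2) - exp(-W_beta d^2)."""
    return W_a * math.exp(-W_gamma * d**2) - math.exp(-W_beta * d**2)

print(W_beta, W_gamma)
print(kernel(0.0), kernel(20.0))                   # excitatory centre, inhibitory surround
print(kernel(0.0, W_a=1.0), kernel(5.0, W_a=1.0))  # W_a = 1: zero at the centre, negative elsewhere
```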
- positions
Neuron positions in 2D lattice, shape (num, 2)
- Type:
Array
- vec_pref
Preferred velocity directions (unit vectors), shape (num, 2)
- Type:
Array
- conn_mat
Asymmetric connectivity matrix, shape (num, num)
- Type:
Array
- s
Neural activity/potential, shape (num,)
- Type:
Variable
- r
Firing rates (ReLU of s), shape (num,)
- Type:
Variable
- center_position
Decoded position in real space, shape (2,)
- Type:
Variable
Example
>>> import brainpy.math as bm
>>> from canns.models.basic import GridCell2DVelocity
>>>
>>> bm.set_dt(5e-4)  # Small timestep for accurate integration
>>> model = GridCell2DVelocity(length=40)
>>>
>>> # Healing process (critical!)
>>> model.heal_network()
>>>
>>> # Update with 2D velocity
>>> velocity = [0.1, 0.05]  # [vx, vy] in m/s
>>> model.update(velocity)
References
Burak, Y., & Fiete, I. R. (2009). Accurate path integration in continuous attractor network models of grid cells. PLoS Computational Biology, 5(2), e1000291.
Initialize the Burak & Fiete grid cell network.
- Parameters:
use_sparse (bool) – Whether to use sparse matrix for connectivity (experimental). Default: False. Sparse matrices may be faster on GPU but slower on CPU. Requires brainevent library.
- compute_velocity_input(velocity)
Compute velocity-modulated input: B(v) = A * (1 + α·v·v_pref)
Neurons whose preferred direction aligns with the velocity receive stronger input, creating directional modulation that drives bump shifts.
- Parameters:
velocity – 2D velocity vector [vx, vy], shape (2,)
- Returns:
Input to each neuron
- Return type:
Array of shape (num,)
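A vectorized sketch of B(v) for four neurons with east, north, west and south preferences (illustrative vec_pref values, not the class's lattice). With v = [0.1, 0.05], the east-preferring neuron receives the largest drive and the west-preferring one the smallest, which is the directional bias that pushes the bump.

```python
import numpy as np

def compute_velocity_input(velocity, vec_pref, A=1.0, alpha=0.2):
    """B(v) = A * (1 + alpha * v . v_pref), evaluated for every neuron at once."""
    velocity = np.asarray(velocity, dtype=float)
    return A * (1.0 + alpha * (vec_pref @ velocity))

# Four neurons preferring east, north, west and south (illustrative values)
vec_pref = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
B = compute_velocity_input([0.1, 0.05], vec_pref)
print(B)  # east-preferring neuron gets the largest drive
```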
- decode_position_from_activity(activity)
Decode position from neural activity using the population vector method.
This method analyzes the activity bump to determine the network’s internal representation of position. The current implementation is simplified.
- Parameters:
activity – Neural activity, shape (num,)
- Returns:
position – Decoded 2D position, shape (2,)
- decode_position_lsq(activity_history, velocity_history)
Decode position using velocity integration (simple method).
For proper position decoding from neural activity, a more sophisticated method would fit the activity to spatial basis functions. For now, velocity integration is used as ground truth and error metrics are computed against it.
- Parameters:
activity_history – Neural activity over time, shape (T, num)
velocity_history – Velocity over time, shape (T, 2)
- Returns:
A tuple (decoded_positions, r_squared):
decoded_positions – Integrated positions, shape (T, 2)
r_squared – R² score comparing integrated and true positions (if available)
- Return type:
tuple
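A sketch of the velocity-integration baseline and an R² score. The method's signature takes no time step, so dt is passed explicitly here as an assumption; the R² below is pooled over both coordinates, which may not match the library's metric, and the helper names integrate_velocity and r_squared are hypothetical.

```python
import numpy as np

def integrate_velocity(velocity_history, dt, start=(0.0, 0.0)):
    """Dead-reckoned positions: start plus the cumulative sum of v * dt, shape (T, 2)."""
    v = np.asarray(velocity_history, dtype=float)
    return np.asarray(start, dtype=float) + np.cumsum(v * dt, axis=0)

def r_squared(estimate, truth):
    """Coefficient of determination, pooled over both coordinates."""
    ss_res = np.sum((truth - estimate) ** 2)
    ss_tot = np.sum((truth - truth.mean(axis=0)) ** 2)
    return 1.0 - ss_res / ss_tot

velocity = np.tile([0.1, 0.0], (4, 1))  # constant eastward velocity over T = 4 steps
positions = integrate_velocity(velocity, dt=0.5)
print(positions[:, 0])                  # approximately [0.05 0.1 0.15 0.2]
print(r_squared(positions, positions))  # 1.0 for a perfect match
```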
- handle_periodic_condition(d)
Apply periodic boundary conditions to neuron position differences.
- Parameters:
d – Position differences, shape (…, 2)
- Returns:
Wrapped differences with periodic boundaries
- heal_network(num_healing_steps=2500, dt_healing=0.0001)
Healing process that forms a stable activity bump before simulation (optimized).
This process is critical for proper initialization. It relaxes the network to a stable attractor state through a sequence of movements:
1. Relax with zero velocity (T=0.25s)
2. Move in 4 cardinal directions (0°, 90°, 180°, 270°)
3. Relax again with zero velocity (T=0.25s)
- Parameters:
num_healing_steps – Total number of healing steps. Default: 2500
dt_healing – Small timestep for healing integration. Default: 1e-4
Note
This temporarily changes the global timestep. The original timestep is restored after healing. Uses bm.for_loop for efficient execution.
- make_connection()
Build asymmetric connectivity matrix with spatial shifts (vectorized).
The connectivity from neuron i to j depends on the distance between them, shifted by neuron i’s preferred velocity direction:
distance = |pos_j - pos_i - W_l * vec_pref_i|
This creates asymmetric connectivity that enables velocity-driven bump shifts for path integration.
- Connectivity kernel:
W_ij = W_a * exp(-W_gamma * d²) - exp(-W_beta * d²)
Note
This implementation uses JAX broadcasting for efficient computation. All pairwise distances are computed simultaneously, avoiding Python loops.
If use_sparse=True, converts to brainevent.CSR sparse matrix format. Sparse matrices reduce memory usage for large networks but may be slower on CPU. They are primarily intended for GPU acceleration.
- Returns:
Dense array of shape (num, num), or brainevent.CSR if use_sparse=True
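A NumPy sketch of the vectorized construction, assuming (as in Burak & Fiete's model) that preferred directions tile the sheet in 2×2 blocks of east, north, west and south, and that the kernel is W_a·exp(-W_gamma·d²) - exp(-W_beta·d²). Each row i peaks at the neuron displaced by W_l·vec_pref_i, which is the asymmetry that lets velocity input shift the bump. The function name make_bf_connection is hypothetical, length=32 is chosen to keep memory small, and the sparse brainevent path is not reproduced.

```python
import numpy as np

def make_bf_connection(length=32, W_a=1.5, W_l=2.0, lambda_net=15.0, e=1.15):
    W_beta = 3.0 / lambda_net**2
    W_gamma = e * W_beta
    ii, jj = np.meshgrid(np.arange(length), np.arange(length), indexing="ij")
    positions = np.stack([ii.ravel(), jj.ravel()], axis=1).astype(float) - length / 2
    # Assumed layout: each 2x2 tile holds E, N, W, S preferred directions
    dirs = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
    vec_pref = dirs[((ii % 2) * 2 + (jj % 2)).ravel()]
    # d[i, j] = pos_j - pos_i - W_l * vec_pref_i, wrapped onto the periodic sheet
    d = positions[None, :, :] - positions[:, None, :] - W_l * vec_pref[:, None, :]
    d = (d + length / 2) % length - length / 2
    d_sq = (d ** 2).sum(axis=-1)
    W = W_a * np.exp(-W_gamma * d_sq) - np.exp(-W_beta * d_sq)
    return positions, vec_pref, W

positions, vec_pref, W = make_bf_connection()
target = np.argmax(W[0])
print(positions[0], vec_pref[0], positions[target])  # target sits W_l steps along vec_pref[0]
```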
- static track_blob_centers(activities, length)
Track blob centers using Gaussian filtering and thresholding.
This is the robust method from the Burak & Fiete (2009) reference implementation; it achieves R² > 0.99 for path integration quality.
- Parameters:
activities – Neural activities, shape (T, num)
length – Grid size (e.g., 40 for a 40×40 grid)
- Returns:
centers – Blob centers in neuron coordinates, shape (T, 2)
Example
>>> import numpy as np
>>> from canns.models.basic import GridCell2DVelocity
>>> activities = np.array([...])  # (T, 1600) for a 40×40 grid
>>> centers = GridCell2DVelocity.track_blob_centers(activities, length=40)
>>> # centers.shape == (T, 2)
- update(velocity)
Single timestep update with velocity input.
- Dynamics:
ds/dt = (1/tau) * [-s + W·r + B(v)]
r = ReLU(s) = max(s, 0)
- Parameters:
velocity – 2D velocity [vx, vy], shape (2,)
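One Euler step of these dynamics, as a sketch. It assumes W[i, j] holds the weight from neuron i onto neuron j (matching the make_connection description), so the recurrent term is W.T @ r; if the matrix is stored the other way round, W @ r applies instead. The values tau=0.01 and dt=5e-4 follow the class default and the docstring example. With W set to zero, s relaxes to B(v) and r to its positive part, which makes an easy sanity check.

```python
import numpy as np

def bf_update(s, W, B, tau=0.01, dt=5e-4):
    """One forward-Euler step of ds/dt = (1/tau) * (-s + recurrent + B), then r = ReLU(s)."""
    r = np.maximum(s, 0.0)
    recurrent = W.T @ r  # assumes W[i, j] is the weight from neuron i onto neuron j
    s = s + dt / tau * (-s + recurrent + B)
    return s, np.maximum(s, 0.0)

# With no recurrence, s relaxes exponentially to the feedforward drive B
B = np.array([1.0, 2.0, -1.0, 0.5])
s = np.zeros(4)
for _ in range(1000):
    s, r = bf_update(s, np.zeros((4, 4)), B)
print(r)  # approaches max(B, 0)
```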
- A = 1.0
- W_a = 1.5
- W_beta = 0.013333333333333334
- W_gamma = 0.015333333333333332
- W_l = 2.0
- alpha = 0.2
- center_position
- conn_mat
- e = 1.15
- lambda_net = 15.0
- length = 40
- num = 1600
- positions
- r
- s
- tau = 0.01
- use_sparse = False
- vec_pref
- class src.canns.models.basic.HierarchicalNetwork(num_module, num_place, spacing_min=2.0, spacing_max=5.0, module_angle=0.0, band_size=180, band_noise=0.0, band_w_L2S=0.2, band_w_S2L=1.0, band_gain=0.2, grid_num=20, grid_tau=0.1, grid_tau_v=10.0, grid_k=0.005, grid_a=bm.pi / 9, grid_A=1.0, grid_J0=1.0, grid_mbar=1.0, gauss_tau=1.0, gauss_J0=1.1, gauss_k=0.0005, gauss_a=2 / 9 * bm.pi, nonrec_tau=0.1)
Bases: src.canns.models.basic._base.BasicModelGroup
A full hierarchical network composed of multiple grid modules.
This class creates and manages a collection of HierarchicalPathIntegrationModel modules, each with a different grid spacing. By combining the outputs of these modules, the network can represent position unambiguously over a large area. The final output is a population of place cells whose activities are used to decode the animal’s estimated position.
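Why several spacings remove ambiguity can be seen in a 1D toy version: a single module only knows position modulo its spacing, whereas modules with spacings 2, 3 and 5 (hypothetical values) jointly pin down position over 30 units. This is not the HierarchicalNetwork decoder, which works with 2D grid modules, band cells and place cells; it only illustrates the principle with a brute-force phase-matching search, and the helper names are hypothetical.

```python
import numpy as np

def module_phases(x, spacings):
    """Phase of 1D position(s) x within each module's period, in radians."""
    return 2 * np.pi * np.mod(np.asarray(x)[..., None], spacings) / spacings

def phase_error(candidates, observed, spacings):
    """Summed squared circular error between candidate and observed phases."""
    diff = module_phases(candidates, spacings) - observed
    diff = (diff + np.pi) % (2 * np.pi) - np.pi
    return (diff ** 2).sum(axis=-1)

candidates = np.arange(0.0, 30.0, 0.01)  # search range of 30 units
spacings = np.array([2.0, 3.0, 5.0])     # hypothetical module spacings
err = phase_error(candidates, module_phases(7.3, spacings), spacings)
print(candidates[np.argmin(err)])        # close to 7.3

single = np.array([2.0])
err_single = phase_error(candidates, module_phases(7.3, single), single)
print(np.sum(err_single < 1e-6))         # many positions fit a single module equally well
```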
- place_center
The center locations of the place cells.
- Type:
bm.math.ndarray
- grid_fr
The firing rates of the grid cell population.
- Type:
bm.Variable
- band_x_fr
The firing rates of the x-oriented band cell population.
- Type:
bm.Variable
- band_y_fr
The firing rates of the y-oriented band cell population.
- Type:
bm.Variable
- place_fr
The firing rates of the place cell population.
- Type:
bm.Variable
- decoded_pos
The final decoded 2D position.
- Type:
bm.Variable
References
Anonymous Author(s) “Unfolding the Black Box of Recurrent Neural Networks for Path Integration” (under review).
Initializes the HierarchicalNetwork.
- Parameters:
num_module (int) – The number of grid modules to create.
num_place (int) – The number of place cells along one dimension of a square grid.
spacing_min (float, optional) – Minimum spacing for grid modules. Defaults to 2.0.
spacing_max (float, optional) – Maximum spacing for grid modules. Defaults to 5.0.
module_angle (float, optional) – Base orientation angle for all modules. Defaults to 0.0.
band_size (int, optional) – Number of neurons in each BandCell group. Defaults to 180.
band_noise (float, optional) – Noise level for BandCells. Defaults to 0.0.
band_w_L2S (float, optional) – Weight from band cells to shifter units. Defaults to 0.2.
band_w_S2L (float, optional) – Weight from shifter units to band cells. Defaults to 1.0.
band_gain (float, optional) – Gain factor for velocity signal in BandCells. Defaults to 0.2.
grid_num (int, optional) – Number of neurons per dimension for GridCell. Defaults to 20.
grid_tau (float, optional) – Synaptic time constant for GridCell. Defaults to 0.1.
grid_tau_v (float, optional) – Adaptation time constant for GridCell. Defaults to 10.0.
grid_k (float, optional) – Global inhibition strength for GridCell. Defaults to 5e-3.
grid_a (float, optional) – Connection width for GridCell. Defaults to pi/9.
grid_A (float, optional) – External input magnitude for GridCell. Defaults to 1.0.
grid_J0 (float, optional) – Maximum connection strength for GridCell. Defaults to 1.0.
grid_mbar (float, optional) – Base adaptation strength for GridCell. Defaults to 1.0.
gauss_tau (float, optional) – Time constant for GaussRecUnits in BandCells. Defaults to 1.0.
gauss_J0 (float, optional) – Connection strength scaling for GaussRecUnits. Defaults to 1.1.
gauss_k (float, optional) – Global inhibition for GaussRecUnits. Defaults to 5e-4.
gauss_a (float, optional) – Connection width for GaussRecUnits. Defaults to 2/9*pi.
nonrec_tau (float, optional) – Time constant for NonRecUnits in BandCells. Defaults to 0.1.
- MEC_model_list = []
- band_x_fr
- band_y_fr
- decoded_pos
- grid_fr
- num_module
- num_place
- place_center
- place_fr