src.canns.models.basic

Submodules

Classes

CANN1D

A standard 1D Continuous Attractor Neural Network (CANN) model.

CANN1D_SFA

A 1D CANN model that incorporates Spike-Frequency Adaptation (SFA).

CANN2D

A 2D Continuous Attractor Neural Network (CANN) model.

CANN2D_SFA

A 2D Continuous Attractor Neural Network (CANN) model that incorporates Spike-Frequency Adaptation (SFA).

GridCell2DPosition

Position-based 2D continuous-attractor grid cell network with hexagonal lattice structure.

GridCell2DVelocity

Velocity-based grid cell network (Burak & Fiete 2009).

HierarchicalNetwork

A full hierarchical network composed of multiple grid modules.

Package Contents

class src.canns.models.basic.CANN1D(*args, **kwargs)[source]

Bases: BaseCANN1D

A standard 1D Continuous Attractor Neural Network (CANN) model. This model implements the core dynamics where a localized “bump” of activity can be sustained and moved by external inputs.

Reference:

Wu, S., Hamaguchi, K., & Amari, S. I. (2008). Dynamics and computation of continuous attractors. Neural Computation, 20(4), 994-1025.

Initializes the 1D CANN model.

Parameters:

(Parameters are inherited from BaseCANN1D.)

update(inp)[source]

The main update function, defining the dynamics of the network for one time step.

Parameters:

inp (Array) – The external input for the current time step.
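The update docstring does not spell out the equations. A minimal NumPy sketch of one time step follows, assuming the standard discretized form of Wu et al. (2008): tau * du/dt = -u + sum_j J_ij r_j + I, with r = u² / (1 + k * sum_j u_j²) and a Gaussian kernel J on a ring. The helper names `make_ring` and `cann1d_step` and all parameter values are illustrative assumptions, not the library's API or defaults.

```python
import numpy as np


def make_ring(num, a=0.5, J0=4.0, z_min=-np.pi, z_max=np.pi):
    """Preferred stimuli and Gaussian recurrent weights for neurons on a ring."""
    z_range = z_max - z_min
    x = np.linspace(z_min, z_max, num, endpoint=False)
    d = np.abs(x[:, None] - x[None, :])
    d = np.minimum(d, z_range - d)  # periodic (ring) distance
    conn = J0 / (np.sqrt(2 * np.pi) * a) * np.exp(-0.5 * (d / a) ** 2)
    return x, conn


def cann1d_step(u, inp, conn, k=0.5, tau=1.0, dt=0.1):
    """One Euler step of tau * du/dt = -u + conn @ r + inp."""
    r1 = np.square(u)
    r = r1 / (1.0 + k * np.sum(r1))  # divisive global inhibition
    du = (-u + conn @ r + inp) / tau
    return u + du * dt, r


# Drive the ring with a Gaussian stimulus centred at 0 and let it settle.
x, conn = make_ring(128)
inp = np.exp(-x ** 2 / (4 * 0.5 ** 2))
u = np.zeros_like(x)
for _ in range(200):
    u, r = cann1d_step(u, inp, conn)
```

The divisive normalization keeps the total activity bounded, so the bump settles at the stimulus location instead of growing without limit.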

inp
r
u
class src.canns.models.basic.CANN1D_SFA(num, tau=1.0, tau_v=50.0, k=8.1, a=0.3, A=0.2, J0=1.0, z_min=-bm.pi, z_max=bm.pi, m=0.3, **kwargs)[source]

Bases: BaseCANN1D

A 1D CANN model that incorporates Spike-Frequency Adaptation (SFA). SFA is a slow negative feedback mechanism that causes neurons to fire less over time for a sustained input, which can induce anticipative tracking behavior.

Reference:

Mi, Y., Fung, C. C., Wong, K. Y., & Wu, S. (2014). Spike frequency adaptation implements anticipative tracking in continuous attractor neural networks. Advances in Neural Information Processing Systems, 27.

Initializes the 1D CANN model with SFA.

Parameters:
  • tau_v (float) – The time constant for the adaptation variable ‘v’. A larger value means slower adaptation.

  • m (float) – The strength of the adaptation, coupling the membrane potential ‘u’ to the adaptation variable ‘v’.

  • (Other parameters are inherited from BaseCANN1D.)

update(inp)[source]

The main update function for the SFA model. It includes dynamics for both the membrane potential and the adaptation variable.

Parameters:

inp (Array) – The external input for the current time step.
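A minimal sketch of the two coupled equations, assuming the formulation of Mi et al. (2014): tau * du/dt = -u + sum_j J_ij r_j - v + I and tau_v * dv/dt = -v + m * u. Only `tau_v` and `m` use the documented defaults; the kernel, `k`, and the helper names `ring_conn` and `cann1d_sfa_step` are illustrative assumptions, not the library's implementation.

```python
import numpy as np


def ring_conn(num, a=0.5, J0=4.0):
    """Gaussian recurrent weights on a ring over [-pi, pi)."""
    x = np.linspace(-np.pi, np.pi, num, endpoint=False)
    d = np.abs(x[:, None] - x[None, :])
    d = np.minimum(d, 2 * np.pi - d)  # periodic distance
    return x, J0 / (np.sqrt(2 * np.pi) * a) * np.exp(-0.5 * (d / a) ** 2)


def cann1d_sfa_step(u, v, inp, conn, k=0.5, tau=1.0, tau_v=50.0, m=0.3, dt=0.1):
    """One Euler step of the CANN dynamics with spike-frequency adaptation."""
    r1 = np.square(u)
    r = r1 / (1.0 + k * np.sum(r1))
    du = (-u + conn @ r - v + inp) / tau  # adaptation v subtracts from the drive
    dv = (-v + m * u) / tau_v             # v slowly tracks m * u
    return u + du * dt, v + dv * dt


x, conn = ring_conn(128)
inp = np.exp(-x ** 2 / (4 * 0.5 ** 2))
u, v = np.zeros_like(x), np.zeros_like(x)
for _ in range(300):
    u, v = cann1d_sfa_step(u, v, inp, conn)
```

Setting m = 0 removes the feedback term and reduces the step to the plain CANN update; because tau_v is much larger than tau, v builds up slowly behind an active bump.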

inp
m = 0.3
r
tau_v = 50.0
u
v
class src.canns.models.basic.CANN2D(*args, **kwargs)[source]

Bases: BaseCANN2D

A 2D Continuous Attractor Neural Network (CANN) model. This model extends the base CANN2D class to include specific dynamics and properties for a 2D neural network.

Reference:

Wu, S., Hamaguchi, K., & Amari, S. I. (2008). Dynamics and computation of continuous attractors. Neural Computation, 20(4), 994-1025.

Initializes the 2D CANN model.

Parameters:

(Parameters are inherited from BaseCANN2D.)

update(inp)[source]

The main update function, defining the dynamics of the network for one time step.

Parameters:

inp (Array) – The external input to the network, which can be a stimulus or other driving force.
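In 2D the same idea applies on a torus: each neuron has a preferred point in [-π, π)², and distances wrap along both axes. The sketch below builds such a connectivity matrix with NumPy broadcasting; the Gaussian kernel normalization and the name `make_torus_conn` are assumptions for illustration, not the BaseCANN2D code.

```python
import numpy as np


def make_torus_conn(length, a=0.5, J0=4.0, z_min=-np.pi, z_max=np.pi):
    """Preferred points (length**2, 2) and Gaussian weights on a 2D torus."""
    z_range = z_max - z_min
    x = np.linspace(z_min, z_max, length, endpoint=False)
    xx, yy = np.meshgrid(x, x, indexing="ij")
    pos = np.stack([xx.ravel(), yy.ravel()], axis=1)
    d = np.abs(pos[:, None, :] - pos[None, :, :])  # (N, N, 2) axis-wise distances
    d = np.minimum(d, z_range - d)                 # wrap each axis independently
    dist2 = np.sum(d ** 2, axis=-1)
    conn = J0 / (2 * np.pi * a ** 2) * np.exp(-0.5 * dist2 / a ** 2)
    return pos, conn


pos, conn = make_torus_conn(16)
```

Because the torus has no edges, every neuron sees the same set of distances, so all rows of the matrix sum to the same value.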

inp
r
u
class src.canns.models.basic.CANN2D_SFA(length, tau=1.0, tau_v=50.0, k=8.1, a=0.3, A=0.2, J0=1.0, z_min=-bm.pi, z_max=bm.pi, m=0.3, **kwargs)[source]

Bases: BaseCANN2D

A 2D Continuous Attractor Neural Network (CANN) model that incorporates Spike-Frequency Adaptation (SFA). This model extends the base CANN2D class with SFA dynamics: a slow adaptation variable v that feeds back negatively on the membrane potential u.

Initializes the 2D CANN model with SFA dynamics.

update(inp)[source]

The main update function for the SFA model. It includes dynamics for both the membrane potential and the adaptation variable.

Parameters:

inp (Array) – The external input for the current time step.

inp
m = 0.3
r
tau_v = 50.0
u
v
class src.canns.models.basic.GridCell2DPosition(length=30, tau=10.0, k=1.0, a=0.8, A=3.0, J0=5.0, mapping_ratio=1.5, noise_strength=0.1, conn_noise=0.0, g=1.0)[source]

Bases: src.canns.models.basic._base.BasicModel

Position-based 2D continuous-attractor grid cell network with hexagonal lattice structure.

This network implements a twisted torus topology that generates grid cell-like spatial representations with hexagonal periodicity.

The network operates in a transformed coordinate system where grid cells form a hexagonal lattice, enabling realistic grid field spacing and orientation.

Parameters:
  • length (int) – Number of grid cells along one dimension (total = length^2). Default: 30

  • tau (float) – Membrane time constant (ms). Default: 10.0

  • k (float) – Global inhibition strength for divisive normalization. Default: 1.0

  • a (float) – Width of connectivity kernel. Determines bump width. Default: 0.8

  • A (float) – Amplitude of external input. Default: 3.0

  • J0 (float) – Peak recurrent connection strength. Default: 5.0

  • mapping_ratio (float) – Controls grid spacing (larger = smaller spacing). Grid spacing λ = 2π / mapping_ratio. Default: 1.5

  • noise_strength (float) – Standard deviation of activity noise. Default: 0.1

  • conn_noise (float) – Standard deviation of connectivity noise. Default: 0.0

  • g (float) – Firing rate gain factor (scales to biological range). Default: 1.0
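With the defaults, the network has 30² = 900 cells and a grid spacing of 2π / 1.5 ≈ 4.19. The sketch below checks that spacing relation under a simplifying assumption: phase = wrap(mapping_ratio * position) per axis, ignoring the hexagonal coordinate transform. `position2phase_rect` is a hypothetical helper, not the class's `position2phase`.

```python
import numpy as np


def wrap(phase):
    """Wrap values into [-pi, pi)."""
    return np.mod(phase + np.pi, 2 * np.pi) - np.pi


length, mapping_ratio = 30, 1.5
num = length ** 2                   # total number of grid cells
Lambda = 2 * np.pi / mapping_ratio  # grid spacing in real space


def position2phase_rect(position, mapping_ratio=mapping_ratio):
    # Hypothetical simplification: scale and wrap, no hexagonal transform.
    return wrap(mapping_ratio * np.asarray(position, dtype=float))


p = np.array([0.5, 0.3])
shifted = p + np.array([Lambda, 0.0])  # move exactly one grid period along x
```

Moving by one grid period leaves the phase unchanged, which is why a larger mapping_ratio (faster phase advance per unit distance) gives a smaller spacing.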

num

Total number of grid cells (length^2)

Type:

int

x_grid, y_grid

Grid cell preferred phases in [-π, π]

Type:

Array

value_grid

Neuron positions in phase space, shape (num, 2)

Type:

Array

Lambda

Grid spacing in real space

Type:

float

coor_transform

Hexagonal to rectangular coordinate transform

Type:

Array

coor_transform_inv

Rectangular to hexagonal coordinate transform

Type:

Array

conn_mat

Recurrent connectivity matrix

Type:

Array

candidate_centers

Grid of candidate bump centers for decoding

Type:

Array

r

Firing rates (shape: num)

Type:

Variable

u

Membrane potentials (shape: num)

Type:

Variable

center_phase

Decoded bump center in phase space (shape: 2)

Type:

Variable

center_position

Decoded position in real space (shape: 2)

Type:

Variable

inp

External input for tracking (shape: num)

Type:

Variable

gc_bump

Grid cell bump activity pattern (shape: num)

Type:

Variable

Example

>>> import brainpy.math as bm
>>> from canns.models.basic import GridCell2DPosition
>>>
>>> bm.set_dt(1.0)
>>> model = GridCell2DPosition(length=30, mapping_ratio=1.5)
>>>
>>> # Update with 2D position
>>> position = [0.5, 0.3]
>>> model.update(position)
>>>
>>> # Access decoded position
>>> decoded_pos = model.center_position.value
>>> print(f"Decoded position: {decoded_pos}")

References

Burak, Y., & Fiete, I. R. (2009). Accurate path integration in continuous attractor network models of grid cells. PLoS Computational Biology, 5(2), e1000291.

Initialize the simplified grid cell network.

calculate_dist(d)[source]

Calculate Euclidean distance after hexagonal coordinate transformation.

Applies periodic boundary conditions and transforms displacement vectors from phase space to hexagonal lattice coordinates.

Parameters:

d – Displacement vectors in phase space, shape (…, 2)

Returns:

Euclidean distances in hexagonal space

Return type:

Array of shape (…,)
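A sketch of the wrap-then-transform-then-norm pipeline described above. The transform used here is a hypothetical hexagonal basis with lattice vectors (1, 0) and (1/2, √3/2) at 60°; the class's actual `coor_transform` may use a different matrix or orientation.

```python
import numpy as np

# Hypothetical hexagonal basis: columns are lattice vectors 60 degrees apart.
HEX_TRANSFORM = np.array([[1.0, 0.5],
                          [0.0, np.sqrt(3) / 2]])


def wrap(d):
    """Wrap phase differences into [-pi, pi)."""
    return np.mod(d + np.pi, 2 * np.pi) - np.pi


def hex_distance(d, transform=HEX_TRANSFORM):
    """Euclidean length of wrapped phase displacements d (..., 2) in the hex basis."""
    hex_d = wrap(np.asarray(d, dtype=float)) @ transform.T
    return np.linalg.norm(hex_d, axis=-1)


d = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
dist = hex_distance(d)
```

The first three displacements map to equal-length vectors (the six nearest neighbours of a hexagonal lattice come in such triples), while (1, 1) maps to length √3.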

get_stimulus_by_pos(position)[source]

Generate Gaussian stimulus centered at given position.

Useful for previewing input patterns without calling update.

Parameters:

position – 2D position [x, y] in real space

Returns:

External input pattern

Return type:

Array of shape (num,)

get_unique_activity_bump(network_activity, animal_position)[source]

Decode unique bump location from network activity and animal position.

Estimates the activity bump center in phase space using population vector decoding, then maps it to real space and snaps to the nearest candidate center to resolve periodic ambiguity.

Parameters:
  • network_activity – Grid cell firing rates, shape (num,)

  • animal_position – Current animal position for disambiguation, shape (2,)

Returns:

  • center_phase – Phase coordinates of the bump center, shape (2,)

  • center_position – Real-space position of the bump (nearest candidate center), shape (2,)

  • bump – Gaussian bump template centered at center_position, shape (num,)

Return type:

tuple (center_phase, center_position, bump)
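A sketch of the two decoding stages, under stated assumptions: population-vector decoding is done as a circular mean per axis, and disambiguation assumes a rectangular layout where phase maps to a real-space offset phase / mapping_ratio and candidate centers sit on a square lattice of spacing Lambda. Both helpers and the candidate layout are hypothetical; the method itself works in hexagonal coordinates.

```python
import numpy as np


def decode_phase(rates, x_grid, y_grid):
    """Population-vector (circular mean) estimate of the bump centre per axis."""
    px = np.angle(np.sum(rates * np.exp(1j * x_grid)))
    py = np.angle(np.sum(rates * np.exp(1j * y_grid)))
    return np.array([px, py])


def snap_to_candidate(phase, candidates, animal_position, mapping_ratio):
    """Pick the periodic copy of the decoded phase closest to the animal."""
    offset = phase / mapping_ratio  # phase -> offset within one grid period
    copies = candidates + offset    # every position consistent with the phase
    idx = np.argmin(np.linalg.norm(copies - animal_position, axis=1))
    return copies[idx]


length, mapping_ratio = 30, 1.5
phases_1d = np.linspace(-np.pi, np.pi, length, endpoint=False)
xx, yy = np.meshgrid(phases_1d, phases_1d, indexing="ij")
x_grid, y_grid = xx.ravel(), yy.ravel()

# Synthetic bump centred at phase (1.0, -0.5), using wrapped differences.
dx = np.angle(np.exp(1j * (x_grid - 1.0)))
dy = np.angle(np.exp(1j * (y_grid + 0.5)))
rates = np.exp(-(dx ** 2 + dy ** 2) / (2 * 0.5 ** 2))

phase = decode_phase(rates, x_grid, y_grid)
Lambda = 2 * np.pi / mapping_ratio
k = np.arange(-2, 3)
candidates = np.stack(np.meshgrid(k, k, indexing="ij"), axis=-1).reshape(-1, 2) * Lambda
position = snap_to_candidate(phase, candidates, np.array([4.8, -0.3]), mapping_ratio)
```

The phase alone is ambiguous up to whole grid periods; the animal position only chooses which period to report.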

handle_periodic_condition(d)[source]

Apply periodic boundary conditions to wrap phases into [-π, π].

Parameters:

d – Phase values (any shape with last dimension = 2)

Returns:

Wrapped phase values in [-π, π]

make_candidate_centers(Lambda)[source]

Generate grid of candidate bump centers for decoding.

Creates a regular lattice of potential activity bump locations used for disambiguating position from grid cell phases.

Parameters:

Lambda – Grid spacing in real space

Returns:

Candidate centers in transformed coordinates

Return type:

Array of shape (N_c*N_c, 2)

make_connection()[source]

Generate recurrent connectivity matrix with 2D Gaussian kernel.

Uses hexagonal lattice geometry via coordinate transformation. Connection strength decays with distance in transformed space.

Returns:

Recurrent connectivity matrix

Return type:

Array of shape (num, num)

position2phase(position)[source]

Convert real-space position to grid cell phase coordinates.

Applies the coordinate transformation and wraps the result to periodic boundaries, so the returned phase encodes the animal's position within one period of the hexagonal lattice.

Parameters:

position – Real-space coordinates, shape (2,) or (2, N)

Returns:

Phase coordinates in [-π, π] per axis

Return type:

Array of shape (2,) or (2, N)

update(position)[source]

Single time-step update of grid cell network dynamics.

Updates network activity using continuous attractor dynamics with direct position-based external input. No adaptation or theta modulation.

Parameters:

position – Current 2D position [x, y] in real space, shape (2,)

A = 3.0
J0 = 5.0
Lambda
a = 0.8
candidate_centers
center_phase
center_position
conn_mat
conn_noise = 0.0
coor_transform
coor_transform_inv
g = 1.0
gc_bump
inp
k = 1.0
length = 30
mapping_ratio = 1.5
noise_strength = 0.1
num = 900
r
tau = 10.0
u
value_bump
value_grid
x_grid
y_grid
class src.canns.models.basic.GridCell2DVelocity(length=40, tau=0.01, alpha=0.2, A=1.0, W_a=1.5, W_l=2.0, lambda_net=15.0, e=1.15, use_sparse=False)[source]

Bases: src.canns.models.basic._base.BasicModel

Velocity-based grid cell network (Burak & Fiete 2009).

This network implements path integration through velocity-modulated input and asymmetric connectivity. Unlike position-based models, this takes velocity as input and integrates it over time to track position.

Key Features:
  • Velocity-dependent input modulation: B(v) = A * (1 + α·v·v_pref)

  • Asymmetric connectivity shifted in preferred velocity directions

  • Simple ReLU activation (not divisive normalization)

  • Healing process for proper initialization

Parameters:
  • length (int) – Number of neurons along one dimension (total = length²). Default: 40

  • tau (float) – Membrane time constant. Default: 0.01

  • alpha (float) – Velocity coupling strength. Default: 0.2

  • A (float) – Baseline input amplitude. Default: 1.0

  • W_a (float) – Connection amplitude (>1 makes close surround activatory). Default: 1.5

  • W_l (float) – Spatial shift size for asymmetric connectivity. Default: 2.0

  • lambda_net (float) – Lattice constant (neurons between bump centers). Default: 15.0

  • e (float) – Controls inhibitory surround spread. Default: 1.15. W_gamma and W_beta are computed from e and lambda_net.
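The class attributes listed for this model (W_beta = 0.01333..., W_gamma = 0.01533...) are consistent with W_beta = 3 / lambda_net² and W_gamma = e * W_beta, which mirrors the choice beta = 3 / lambda² and gamma = 1.05 * beta in Burak & Fiete (2009), with e in place of 1.05. These relations are inferred from the default values, not read from the source.

```python
lambda_net, e = 15.0, 1.15

# Assumed relations, inferred from the documented default attribute values.
W_beta = 3.0 / lambda_net ** 2  # width of the broader Gaussian
W_gamma = e * W_beta            # e > 1 makes this Gaussian slightly narrower
```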

num

Total number of neurons (length²)

Type:

int

positions

Neuron positions in 2D lattice, shape (num, 2)

Type:

Array

vec_pref

Preferred velocity directions (unit vectors), shape (num, 2)

Type:

Array

conn_mat

Asymmetric connectivity matrix, shape (num, num)

Type:

Array

s

Neural activity/potential, shape (num,)

Type:

Variable

r

Firing rates (ReLU of s), shape (num,)

Type:

Variable

center_position

Decoded position in real space, shape (2,)

Type:

Variable

Example

>>> import brainpy.math as bm
>>> from canns.models.basic import GridCell2DVelocity
>>>
>>> bm.set_dt(5e-4)  # Small timestep for accurate integration
>>> model = GridCell2DVelocity(length=40)
>>>
>>> # Healing process (critical!)
>>> model.heal_network()
>>>
>>> # Update with 2D velocity
>>> velocity = [0.1, 0.05]  # [vx, vy] in m/s
>>> model.update(velocity)

References

Burak, Y., & Fiete, I. R. (2009). Accurate path integration in continuous attractor network models of grid cells. PLoS Computational Biology, 5(2), e1000291.

Initialize the Burak & Fiete grid cell network.

Parameters:

use_sparse (bool) – Whether to use sparse matrix for connectivity (experimental). Default: False. Sparse matrices may be faster on GPU but slower on CPU. Requires brainevent library.

compute_velocity_input(velocity)[source]

Compute velocity-modulated input: B(v) = A * (1 + α·v·v_pref)

Neurons whose preferred direction aligns with the velocity receive stronger input, creating directional modulation that drives bump shifts.

Parameters:

velocity – 2D velocity vector [vx, vy], shape (2,)

Returns:

Input to each neuron

Return type:

Array of shape (num,)
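A sketch of B(v) = A * (1 + α·v·v_pref). The preferred-direction layout is a hypothetical choice (east, north, west, south tiled over each 2×2 block of the sheet, similar in spirit to Burak & Fiete); the class's `vec_pref` may be arranged differently.

```python
import numpy as np


def preferred_directions(length):
    """Hypothetical layout: tile E, N, W, S over each 2x2 block of the sheet."""
    dirs = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
    i, j = np.meshgrid(np.arange(length), np.arange(length), indexing="ij")
    idx = (2 * (i % 2) + (j % 2)).ravel()
    return dirs[idx]  # unit vectors, shape (length**2, 2)


def velocity_input(velocity, vec_pref, A=1.0, alpha=0.2):
    """B(v) = A * (1 + alpha * v . v_pref) for every neuron."""
    return A * (1.0 + alpha * vec_pref @ np.asarray(velocity, dtype=float))


vec_pref = preferred_directions(4)
B = velocity_input([0.1, 0.0], vec_pref)  # moving east
```

With eastward motion, east-preferring neurons get 1.02, west-preferring neurons 0.98, and north/south neurons stay at the baseline A = 1.0.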

decode_position_from_activity(activity)[source]

Decode position from neural activity using population vector method.

This method analyzes the activity bump to determine the network’s internal representation of position. Currently simplified.

Parameters:

activity – Neural activity, shape (num,)

Returns:

Decoded 2D position, shape (2,)

Return type:

position

decode_position_lsq(activity_history, velocity_history)[source]

Decode position using velocity integration (simple method).

For proper position decoding from neural activity, a more sophisticated method would fit the activity to spatial basis functions. For now, we use velocity integration as ground truth and compute error metrics.

Parameters:
  • activity_history – Neural activity over time, shape (T, num)

  • velocity_history – Velocity over time, shape (T, 2)

Returns:

  • decoded_positions – Integrated positions, shape (T, 2)

  • r_squared – R² score (comparing integrated vs. true positions, if available)

Return type:

tuple (decoded_positions, r_squared)
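A sketch of velocity integration and an R² score. The conventions here are assumptions: a fixed time step `dt`, a start at the origin, and positions recorded after each step; the method may differ on any of these.

```python
import numpy as np


def integrate_velocity(velocity_history, dt, start=(0.0, 0.0)):
    """Euler-integrate a (T, 2) velocity trace into a (T, 2) position trace."""
    v = np.asarray(velocity_history, dtype=float)
    return np.asarray(start, dtype=float) + np.cumsum(v * dt, axis=0)


def r_squared(estimate, target):
    """Coefficient of determination pooled over both coordinates."""
    ss_res = np.sum((target - estimate) ** 2)
    ss_tot = np.sum((target - target.mean(axis=0)) ** 2)
    return 1.0 - ss_res / ss_tot


velocity = np.tile([0.1, 0.05], (4, 1))  # constant velocity for 4 steps
positions = integrate_velocity(velocity, dt=0.5)
```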

handle_periodic_condition(d)[source]

Apply periodic boundary conditions to neuron position differences.

Parameters:

d – Position differences, shape (…, 2)

Returns:

Wrapped differences with periodic boundaries

heal_network(num_healing_steps=2500, dt_healing=0.0001)[source]

Healing process to form stable activity bump before simulation (optimized).

This process is critical for proper initialization. It relaxes the network to a stable attractor state through a sequence of movements:

1. Relax with zero velocity (T = 0.25 s).
2. Move in the four cardinal directions (0°, 90°, 180°, 270°).
3. Relax again with zero velocity (T = 0.25 s).

Parameters:
  • num_healing_steps – Total number of healing steps. Default: 2500

  • dt_healing – Small timestep for healing integration. Default: 1e-4

Note

This temporarily changes the global timestep. The original timestep is restored after healing. Uses bm.for_loop for efficient execution.

make_connection()[source]

Build asymmetric connectivity matrix with spatial shifts (vectorized).

The connectivity from neuron i to neuron j depends on the distance between them, shifted by W_l along neuron i's preferred velocity direction.

This creates asymmetric connectivity that enables velocity-driven bump shifts for path integration.

Connectivity kernel:

W_ij = W_a * exp(-W_gamma * d²) - exp(-W_beta * d²)

Note

This implementation uses JAX broadcasting for efficient computation. All pairwise distances are computed simultaneously, avoiding Python loops.

If use_sparse=True, converts to brainevent.CSR sparse matrix format. Sparse matrices reduce memory usage for large networks but may be slower on CPU. They are primarily intended for GPU acceleration.

Returns:

Dense array of shape (num, num), or brainevent.CSR if use_sparse=True
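A broadcasting sketch of the shifted, periodic distance and the kernel. Assumptions: rows index the receiving neuron and columns the sending neuron, the shift is W_l times the sender's preferred direction, the sheet wraps with period `length`, W_beta = 3 / lambda_net² and W_gamma = e * W_beta, and the kernel takes the Burak & Fiete (2009) form W_a * exp(-W_gamma d²) - exp(-W_beta d²), in which W_a > 1 makes the close surround excitatory. `shifted_connectivity` is a hypothetical helper.

```python
import numpy as np


def shifted_connectivity(positions, vec_pref, length, W_a=1.5, W_l=2.0,
                         lambda_net=15.0, e=1.15):
    """W[post, pre] from a shifted, periodically wrapped distance."""
    W_beta = 3.0 / lambda_net ** 2  # assumed relation
    W_gamma = e * W_beta            # assumed relation
    # diff[post, pre] = x_post - x_pre - W_l * e_pre, shape (N, N, 2)
    diff = positions[:, None, :] - positions[None, :, :] - W_l * vec_pref[None, :, :]
    diff = np.mod(diff + length / 2, length) - length / 2  # periodic sheet
    d2 = np.sum(diff ** 2, axis=-1)
    return W_a * np.exp(-W_gamma * d2) - np.exp(-W_beta * d2)


length = 8
ii, jj = np.meshgrid(np.arange(length), np.arange(length), indexing="ij")
positions = np.stack([ii.ravel(), jj.ravel()], axis=1).astype(float)
vec_pref = np.tile([1.0, 0.0], (length * length, 1))  # every neuron prefers +x
W = shifted_connectivity(positions, vec_pref, length)
```

The strongest outgoing weight of neuron 0 lands two sites ahead along +x rather than on itself, and W is no longer symmetric; this asymmetry is what lets velocity input push the bump.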

static track_blob_centers(activities, length)[source]

Track blob centers using Gaussian filtering and thresholding.

This is the robust method from the Burak & Fiete (2009) reference implementation, which achieves R² > 0.99 for path-integration quality.

Parameters:
  • activities – Neural activities, shape (T, num)

  • length – Grid size (e.g., 40 for 40×40 grid)

Returns:

Blob centers in neuron coordinates, shape (T, 2)

Return type:

centers

Example

>>> import numpy as np
>>> activities = np.array([...])  # (T, 1600) for 40×40 grid
>>> centers = GridCell2DVelocity.track_blob_centers(activities, length=40)
>>> # centers.shape == (T, 2)

update(velocity)[source]

Single timestep update with velocity input.

Dynamics:

ds/dt = (1/tau) * [-s + W·r + B(v)]

r = ReLU(s) = max(s, 0)

Parameters:

velocity – 2D velocity [vx, vy], shape (2,)
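A forward-Euler sketch of the dynamics above, assuming `W` and `B` have already been computed. With the documented tau = 0.01 and the example timestep 5e-4, each step moves s 5% of the way toward its target. `velocity_step` is a hypothetical helper.

```python
import numpy as np


def velocity_step(s, W, B, tau=0.01, dt=5e-4):
    """One Euler step of ds/dt = (1/tau) * (-s + W @ r + B), with r = ReLU(s)."""
    r = np.maximum(s, 0.0)
    s_new = s + dt / tau * (-s + W @ r + B)
    return s_new, np.maximum(s_new, 0.0)


# Without recurrence (W = 0), s relaxes exponentially toward the input B.
s, W, B = np.zeros(3), np.zeros((3, 3)), np.ones(3)
s1, r1 = velocity_step(s, W, B)
s_long = s
for _ in range(500):
    s_long, _ = velocity_step(s_long, W, B)
```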

A = 1.0
W_a = 1.5
W_beta = 0.013333333333333334
W_gamma = 0.015333333333333332
W_l = 2.0
alpha = 0.2
center_position
conn_mat
e = 1.15
lambda_net = 15.0
length = 40
num = 1600
positions
r
s
tau = 0.01
use_sparse = False
vec_pref
class src.canns.models.basic.HierarchicalNetwork(num_module, num_place, spacing_min=2.0, spacing_max=5.0, module_angle=0.0, band_size=180, band_noise=0.0, band_w_L2S=0.2, band_w_S2L=1.0, band_gain=0.2, grid_num=20, grid_tau=0.1, grid_tau_v=10.0, grid_k=0.005, grid_a=bm.pi / 9, grid_A=1.0, grid_J0=1.0, grid_mbar=1.0, gauss_tau=1.0, gauss_J0=1.1, gauss_k=0.0005, gauss_a=2 / 9 * bm.pi, nonrec_tau=0.1)[source]

Bases: src.canns.models.basic._base.BasicModelGroup

A full hierarchical network composed of multiple grid modules.

This class creates and manages a collection of HierarchicalPathIntegrationModel modules, each with a different grid spacing. By combining the outputs of these modules, the network can represent position unambiguously over a large area. The final output is a population of place cells whose activities are used to decode the animal’s estimated position.
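Only the spacing range and module count are documented here, not how modules are spaced or decoded. As an illustrative 1D sketch, assuming three modules with spacings evenly spaced between spacing_min = 2.0 and spacing_max = 5.0 and a brute-force phase-matching decoder (both assumptions, not the class's method), the code shows why combining modules removes the ambiguity of a single one.

```python
import numpy as np


def module_phases(x, spacings):
    """Phase (radians) of position x within each module's grid period."""
    return 2 * np.pi * np.mod(x, spacings) / spacings


def decode_from_modules(phases, spacings, search):
    """Return the candidate x whose phases best match every module."""
    cand = module_phases(search[:, None], spacings[None, :])  # (S, M)
    err = np.angle(np.exp(1j * (cand - phases)))               # wrapped error
    return search[np.argmin(np.sum(err ** 2, axis=1))]


spacings = np.linspace(2.0, 5.0, 3)  # hypothetical: evenly spaced modules
true_x = 7.3
phases = module_phases(true_x, spacings)
estimate = decode_from_modules(phases, spacings, np.arange(0.0, 30.0, 0.01))
```

A single module with spacing 2.0 assigns the same phase to 1.3 and 7.3; with all three spacings, the only exact match within [0, 30) is 7.3.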

num_module

The number of grid modules in the network.

Type:

int

num_place

The number of place cells in the output layer.

Type:

int

place_center

The center locations of the place cells.

Type:

bm.math.ndarray

MEC_model_list

A list containing all the HierarchicalPathIntegrationModel instances.

Type:

list

grid_fr

The firing rates of the grid cell population.

Type:

bm.Variable

band_x_fr

The firing rates of the x-oriented band cell population.

Type:

bm.Variable

band_y_fr

The firing rates of the y-oriented band cell population.

Type:

bm.Variable

place_fr

The firing rates of the place cell population.

Type:

bm.Variable

decoded_pos

The final decoded 2D position.

Type:

bm.Variable

References

Anonymous Author(s) “Unfolding the Black Box of Recurrent Neural Networks for Path Integration” (under review).

Initializes the HierarchicalNetwork.

Parameters:
  • num_module (int) – The number of grid modules to create.

  • num_place (int) – The number of place cells along one dimension of a square grid.

  • spacing_min (float, optional) – Minimum spacing for grid modules. Defaults to 2.0.

  • spacing_max (float, optional) – Maximum spacing for grid modules. Defaults to 5.0.

  • module_angle (float, optional) – Base orientation angle for all modules. Defaults to 0.0.

  • band_size (int, optional) – Number of neurons in each BandCell group. Defaults to 180.

  • band_noise (float, optional) – Noise level for BandCells. Defaults to 0.0.

  • band_w_L2S (float, optional) – Weight from band cells to shifter units. Defaults to 0.2.

  • band_w_S2L (float, optional) – Weight from shifter units to band cells. Defaults to 1.0.

  • band_gain (float, optional) – Gain factor for velocity signal in BandCells. Defaults to 0.2.

  • grid_num (int, optional) – Number of neurons per dimension for GridCell. Defaults to 20.

  • grid_tau (float, optional) – Synaptic time constant for GridCell. Defaults to 0.1.

  • grid_tau_v (float, optional) – Adaptation time constant for GridCell. Defaults to 10.0.

  • grid_k (float, optional) – Global inhibition strength for GridCell. Defaults to 5e-3.

  • grid_a (float, optional) – Connection width for GridCell. Defaults to pi/9.

  • grid_A (float, optional) – External input magnitude for GridCell. Defaults to 1.0.

  • grid_J0 (float, optional) – Maximum connection strength for GridCell. Defaults to 1.0.

  • grid_mbar (float, optional) – Base adaptation strength for GridCell. Defaults to 1.0.

  • gauss_tau (float, optional) – Time constant for GaussRecUnits in BandCells. Defaults to 1.0.

  • gauss_J0 (float, optional) – Connection strength scaling for GaussRecUnits. Defaults to 1.1.

  • gauss_k (float, optional) – Global inhibition for GaussRecUnits. Defaults to 5e-4.

  • gauss_a (float, optional) – Connection width for GaussRecUnits. Defaults to 2/9*pi.

  • nonrec_tau (float, optional) – Time constant for NonRecUnits in BandCells. Defaults to 0.1.

update(velocity, loc, loc_input_stre=0.0)[source]
MEC_model_list = []
band_x_fr
band_y_fr
decoded_pos
grid_fr
num_module
num_place
place_center
place_fr