How to Build a CANN Model?

Goal: By the end of this guide, you’ll be able to create and run a basic CANN model.

Estimated Reading Time: 10 minutes


Introduction

In this library, our standard implementation is the canonical, mathematically tractable continuous attractor neural network known as the Wu-Amari-Wong (WAW) model [5, 6, 7, 8]. Thanks to seamless integration with BrainPy [18], a brain dynamics programming framework built on JAX, building this rigorous model is remarkably simple. This guide shows you how to:

  1. Set up the BrainPy environment

  2. Create a CANN1D model instance

  3. Initialize the model state

  4. Run a simple forward pass

The Basics: BrainPy Framework

CANN models are built using BrainPy, which provides:

  • Unified time-step management via brainpy.math

  • State containers (bm.Variable) for managing neural dynamics

  • JIT compilation through bm.for_loop for high performance

  • Automatic differentiation support for gradient-based analysis

All CANN models inherit from bp.DynamicalSystem, which means they follow a consistent interface across the library.
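The "consistent interface" amounts to a simple contract: hold state in containers and advance it with a call. Here is a plain-Python sketch of that pattern (a toy class for illustration only, not BrainPy's actual `bp.DynamicalSystem` implementation):

```python
class ToyDynamicalSystem:
    """Minimal sketch of the state-container + update-call contract."""

    def __init__(self, num):
        self.num = num
        self.u = [0.0] * num  # state container (a bm.Variable in BrainPy)

    def __call__(self, inputs):
        # Calling the model advances its internal state by one step
        self.u = [u + i for u, i in zip(self.u, inputs)]
        return self.u

model = ToyDynamicalSystem(num=4)
out = model([1.0, 2.0, 3.0, 4.0])
print(out)  # [1.0, 2.0, 3.0, 4.0]
```

Every model in the library follows this shape: construct it, then call it once per time step with that step's input.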

Step-by-Step: Creating Your First CANN

1. Set the Time Step

Before creating any model, you must set the simulation time step:

[1]:
import brainpy.math as bm

# Set time step to 0.1 ms (or your preferred value)
bm.set_dt(0.1)

Why this matters: The time step dt controls the granularity of your simulation. All models in your session will use this value for their dynamics updates.
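To build intuition for what dt controls, here is a minimal NumPy sketch (independent of BrainPy and of the CANN model itself) that Euler-integrates a simple leaky variable, du/dt = -u/tau, at two step sizes:

```python
import numpy as np

def simulate_decay(dt, tau=1.0, t_total=5.0, u0=1.0):
    """Euler-integrate du/dt = -u / tau and return the final value of u."""
    u = u0
    for _ in range(int(t_total / dt)):
        u = u + dt * (-u / tau)
    return u

# The exact solution after t_total = 5 * tau is exp(-5).
exact = np.exp(-5.0)
coarse = simulate_decay(dt=0.5)   # large step: less accurate
fine = simulate_decay(dt=0.01)    # small step: much closer to exp(-5)

print(f"dt=0.5 : {coarse:.4f}, error {abs(coarse - exact):.4f}")
print(f"dt=0.01: {fine:.4f}, error {abs(fine - exact):.4f}")
```

A smaller dt tracks the true trajectory more accurately at the cost of more update steps; 0.1 ms is a common compromise for rate models like the CANN.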

2. Import and Create the Model

[2]:
from canns.models.basic import CANN1D

# Create a 1D CANN with 512 neurons
cann = CANN1D(num=512)

What’s happening here:

  • num=512 specifies the number of neurons in the network

  • The model automatically sets up connection weights, neuron positions, and dynamics parameters

  • Default parameters (e.g., connection strength k, time constant tau) are used unless you specify otherwise
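To illustrate the kind of connectivity the constructor builds, here is a NumPy sketch of a translation-invariant Gaussian connection kernel over neuron positions on the ring from -pi to pi (the normalization and parameter values here are illustrative assumptions, not the library's exact code):

```python
import numpy as np

num, a, J0 = 512, 0.5, 1.0           # illustrative width and strength values
x = np.linspace(-np.pi, np.pi, num)  # neuron preferred positions on the ring

# Periodic (ring) distance between each pair of positions
d = x[:, None] - x[None, :]
d = np.where(d > np.pi, d - 2 * np.pi, d)
d = np.where(d < -np.pi, d + 2 * np.pi, d)

# Gaussian connection matrix: nearby neurons excite each other most strongly
J = J0 / (np.sqrt(2 * np.pi) * a) * np.exp(-d**2 / (2 * a**2))

print(J.shape)  # (512, 512)
```

Because the kernel depends only on distance, the connectivity is the same at every position on the ring, which is what lets the activity bump sit stably at any location.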

3. Run a Forward Pass

Now you can call the model to update its state:

[3]:
import jax.numpy as jnp

# Create a simple external input (stimulus at position 0)
external_input = jnp.zeros(512)

# Run one time step
cann(external_input)

# Access the model's current state
print("Synaptic input:", cann.u.value)
print("Neural activity:", cann.r.value)
Synaptic input: [0. 0. 0. ... 0. 0. 0.]
Neural activity: [0. 0. 0. ... 0. 0. 0.]

(Output truncated: both arrays contain 512 zeros, since the input was zero and the state starts at rest.)

What’s happening:

  • The model takes external input and updates its internal dynamics

  • cann.u stores synaptic input (membrane potential)

  • cann.r stores neural firing rates (activity)

  • Each call to cann(...) advances the model by one time step (dt)
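Schematically, each call performs one Euler step of the canonical WAW dynamics: the firing rate r is a divisively normalized square of the rectified synaptic input u, and u relaxes toward the recurrent plus external drive. Here is a NumPy sketch of that update (the library's exact discretization, parameter names, and normalization may differ):

```python
import numpy as np

def waw_step(u, J, I_ext, dt=0.1, tau=1.0, k=1.0, rho=1.0):
    """One Euler step of Wu-Amari-Wong-style dynamics (illustrative sketch).

    r = u_+^2 / (1 + k * rho * sum(u_+^2))    (divisive normalization)
    tau * du/dt = -u + rho * J @ r + I_ext
    """
    u_plus = np.maximum(u, 0.0)
    r = u_plus**2 / (1.0 + k * rho * np.sum(u_plus**2))
    du = (-u + rho * J @ r + I_ext) / tau
    return u + dt * du, r

# Toy usage: 8 neurons, identity coupling, constant external input
num = 8
u = np.zeros(num)
J = np.eye(num)
I_ext = np.ones(num)
u, r = waw_step(u, J, I_ext)
print(u)  # the input starts charging the synaptic state
```

This is why u responds before r in the library output above: the external input drives u directly, while r only becomes nonzero once u does.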

Complete Working Example

Here’s a minimal, runnable example that puts it all together:

[4]:
import brainpy.math as bm
import jax.numpy as jnp
from canns.models.basic import CANN1D

# Step 1: Set time step
bm.set_dt(0.1)

# Step 2: Create model
cann = CANN1D(num=512)

# Step 3: Create a Gaussian bump stimulus centered at position 0
positions = cann.x  # Neuron positions from -pi to pi
stimulus = jnp.exp(-0.5 * (positions - 0.0)**2 / 0.25**2)

# Step 4: Run several forward passes
cann(stimulus)
cann(stimulus)
cann(stimulus)

# Step 5: Check the output
print(f"Activity shape: {cann.r.value.shape}")
print(f"Max activity: {jnp.max(cann.r.value)}")
Activity shape: (512,)
Max activity: 0.002971156034618616
Customizing Model Parameters

All constructor arguments have sensible defaults; you can override them explicitly:

[ ]:
cann = CANN1D(
    num=512,           # Number of neurons
    k=1.0,             # Global connection strength
    tau=1.0,           # Time constant (ms)
    a=0.5,             # Width of excitatory connections
    A=10.0,            # Amplitude of excitatory connections
    J0=1.0,            # External input strength
)

Key parameters:

  • num: Number of neurons (higher = finer spatial resolution, but slower)

  • k: Controls overall connection strength (higher = stronger self-organization)

  • tau: Time constant of dynamics (higher = slower changes)

  • a: Width of connection kernel (controls bump width)

  • A: Amplitude of connections (affects stability)

For most applications, the defaults work well. We’ll explore parameter tuning in the Core Concepts section.
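As a quick illustration of one of these knobs, here is a NumPy sketch of how the kernel width a sets the spatial footprint of excitation, and hence the bump width (illustrative only; it does not call the library):

```python
import numpy as np

x = np.linspace(-np.pi, np.pi, 512)

def kernel(a):
    """Gaussian excitation profile with width parameter a."""
    return np.exp(-x**2 / (2 * a**2))

# Count positions receiving more than half-maximal excitation
narrow = np.sum(kernel(0.25) > 0.5)
wide = np.sum(kernel(1.0) > 0.5)
print(narrow, wide)  # the wider kernel excites roughly 4x as many neighbors
```

The half-maximal footprint scales linearly with a, so quadrupling a quadruples the number of strongly excited neighbors.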

Running Multiple Time Steps

In practice, you’ll run many time steps in a loop. BrainPy provides optimized tools for this:

[5]:
def step_function(t, stimulus):
    """Run one time step of the model."""
    cann(stimulus)
    return cann.r.value  # Return activity for each step

# Create stimuli for 100 time steps (here, constant stimulus)
stimuli = jnp.tile(stimulus, (100, 1))

# Run optimized loop with progress bar
activities = bm.for_loop(
    step_function,
    operands=(jnp.arange(0, 100), stimuli),  # Per-step time values and inputs
    progress_bar=10  # Show progress (updates every 10%)
)

print(f"Recorded activities shape: {activities.shape}")  # (100, 512)
Recorded activities shape: (100, 512)

What’s happening:

  • bm.for_loop JIT-compiles the loop for speed

  • Progress bar shows simulation progress (updates every 10%)

  • The result is a JAX array of all recorded activities
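For readers unfamiliar with the scan pattern, bm.for_loop conceptually does the following: apply the step function to one slice of each operand per iteration and stack the per-step results. A plain-Python/NumPy sketch of that pattern (the real version JIT-compiles the body; the toy names here are assumptions):

```python
import numpy as np

def toy_for_loop(step_fn, operands):
    """Apply step_fn to each slice of the operands and stack the results."""
    ts, xs = operands
    return np.stack([step_fn(t, x) for t, x in zip(ts, xs)])

# Toy step function: a leaky state relaxing toward each input
state = {"u": 0.0}
def step(t, x):
    state["u"] += 0.1 * (x - state["u"])
    return state["u"]

outs = toy_for_loop(step, (np.arange(5), np.ones(5)))
print(outs.shape)  # (5,)
```

The stacked result is exactly what you see above: one recorded activity vector per time step, giving a (100, 512) array.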

Common Mistakes and How to Avoid Them

❌ Mistake 1: Wrong Input Dimensions

[6]:
cann = CANN1D(num=512)
try:
    cann(jnp.zeros(256))  # ERROR! Input size doesn't match num neurons
except Exception as e:
    print(f"Caught error as expected: {e}")
Caught error as expected: The shape of the original data is (512,), while we got (256,) with batch_axis=None.

✅ Solution: Input must have the same size as num:

[7]:
cann = CANN1D(num=512)
cann(jnp.zeros(512))  # Correct size

❌ Mistake 2: Not Setting the Time Step

[8]:
from canns.models.basic import CANN1D
cann = CANN1D(num=512)  # Uses whatever dt was set before (or default)

✅ Solution: Explicitly set dt at the start of your script:

[ ]:
import brainpy.math as bm
bm.set_dt(0.1)  # Set dt first
cann = CANN1D(num=512)

What About 2D CANNs?

The same principles apply to 2D models:

[9]:
from canns.models.basic import CANN2D

bm.set_dt(0.1)

# Create 2D CANN with 32x32 neurons
cann2d = CANN2D(32)

# Input must be (32, 32) shaped
stimulus_2d = jnp.zeros((32, 32))
cann2d(stimulus_2d)

print(f"2D activity shape: {cann2d.r.value.shape}")  # (32, 32)
2D activity shape: (32, 32)

The API is nearly identical—just adapt your input dimensions!
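If you want a localized 2D stimulus rather than zeros, here is a NumPy sketch of a Gaussian bump on a 32x32 grid (the position range and bump width are illustrative assumptions):

```python
import numpy as np

size = 32
xs = np.linspace(-np.pi, np.pi, size)
X, Y = np.meshgrid(xs, xs)

# Gaussian bump centered at (0, 0) with width 0.5
stimulus_2d = np.exp(-((X - 0.0)**2 + (Y - 0.0)**2) / (2 * 0.5**2))

print(stimulus_2d.shape)  # (32, 32)
```

Pass an array shaped like this to cann2d(...) just as you passed the zeros array above.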

Next Steps

Now that you know how to create and run CANN models, you’re ready to explore parameter tuning and model dynamics in the Core Concepts section.


Have questions? Open a GitHub Discussion.