如何构建 CANN 模型?

目标:完成本指南后,您将能够创建并运行一个基础的 CANN 模型。

预计阅读时间:10 分钟


介绍

在本库中,我们以经典的 Wu-Amari-Wong (WAW) 连续吸引子神经网络 (CANNs) [5, 6, 7, 8] 作为标准实现。得益于与 BrainPy [18]——一个基于 JAX 构建的强大脑动力学编程框架——的无缝集成,构建这一数学上严谨的模型变得非常简单。本指南将向您展示如何:

  1. 设置 BrainPy 环境

  2. 创建 CANN1D 模型实例

  3. 初始化模型状态

  4. 运行简单的前向传播

基础知识:BrainPy 框架

CANN 模型使用 BrainPy 构建,它提供:

  • 统一的时间步管理,通过 brainpy.math 实现

  • 状态容器 (bm.Variable),用于管理神经动力学

  • JIT 编译,通过 bm.for_loop 实现高性能

  • 自动微分 支持,用于基于梯度的分析

所有 CANN 模型都继承自 bp.DynamicalSystem,这意味着它们在库中遵循一致的接口。

逐步指南:创建您的第一个 CANN

1. 设置时间步

在创建任何模型之前,您必须设置模拟时间步:

[1]:
import brainpy.math as bm

# Set time step to 0.1 ms (or your preferred value)
bm.set_dt(0.1)

为什么这很重要:时间步 dt 控制您模拟的粒度。您会话中的所有模型都将使用此值进行动力学更新。

2. 导入并创建模型

[2]:
from canns.models.basic import CANN1D

# Create a 1D CANN with 512 neurons
cann = CANN1D(num=512)

这里发生了什么: - num=512 指定网络中的神经元数量 - 该模型自动设置连接权重、神经元位置和动力学参数 - 使用默认参数(例如连接强度 k、时间常数 tau),除非您指定其他参数

3. 运行前向传播

现在您可以调用模型来更新其状态:

[3]:
import jax.numpy as jnp

# Create a simple external input (stimulus at position 0)
external_input = jnp.zeros(512)

# Run one time step
cann(external_input)

# Access the model's current state
print("Synaptic input:", cann.u.value)
print("Neural activity:", cann.r.value)
Synaptic input: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0.]
Neural activity: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0.]

发生了什么

  • 模型接收外部输入并更新其内部动力学

  • cann.u 存储突触输入(膜电位)

  • cann.r 存储神经放电速率(活动)

  • 每次调用 cann(...) 将模型推进一个时间步(dt

完整的工作示例

这是一个将所有内容整合在一起的最小可运行示例:

[4]:
import brainpy.math as bm  # :cite:p:`wang2023brainpy`
import jax.numpy as jnp
from canns.models.basic import CANN1D

# Step 1: Set time step
bm.set_dt(0.1)

# Step 2: Create model
cann = CANN1D(num=512)

# Step 3: Create a Gaussian bump stimulus centered at position 0
positions = cann.x  # Neuron positions from -pi to pi
stimulus = jnp.exp(-0.5 * (positions - 0.0)**2 / 0.25**2)

# Step 4: Run several forward pass
cann(stimulus)
cann(stimulus)
cann(stimulus)

# Step 5: Check the output
print(f"Activity shape: {cann.r.value.shape}")
print(f"Max activity: {jnp.max(cann.r.value)}")
Activity shape: (512,)
Max activity: 0.002971156034618616
[ ]:
cann = CANN1D(
    num=512,           # Number of neurons
    k=1.0,             # Global connection strength
    tau=1.0,           # Time constant (ms)
    a=0.5,             # Width of excitatory connections
    A=10.0,            # Amplitude of excitatory connections
    J0=1.0,            # External input strength
)

关键参数

  • num: 神经元数量(数值越高 = 空间分辨率越精细,但速度越慢)

  • k: 控制整体连接强度(数值越高 = 自组织能力越强)

  • tau: 动力学的时间常数(数值越高 = 变化越慢)

  • a: 连接核的宽度(控制波包宽度)

  • A: 连接的幅度(影响稳定性)

对于大多数应用,默认参数表现良好。我们将在核心概念部分探索参数调整。

运行多个时间步

在实践中,您将在循环中运行许多时间步。BrainPy 为此提供了优化的工具:

[5]:
def step_function(t, stimulus):
    """Run one time step of the model."""
    cann(stimulus)
    return cann.r.value  # Return activity for each step

# Create stimuli for 100 time steps (here, constant stimulus)
stimuli = jnp.tile(stimulus, (100, 1))

# Run optimized loop with progress bar
activities = bm.for_loop(
    step_function,
    operands=(jnp.arange(0, 100), stimuli),  # Number of steps and input data
    progress_bar=10  # Show progress (updates every 10%)
)

print(f"Recorded activities shape: {activities.shape}")  # (100, 512)
Recorded activities shape: (100, 512)

发生了什么

  • bm.for_loop 对循环进行 JIT 编译以提高速度

  • 进度条显示模拟进度(每 10% 更新一次)

  • 结果是所有已记录活动的 JAX 数组

常见错误及如何避免

❌ 错误 1:错误的输入维度

[6]:
cann = CANN1D(num=512)
try:
    cann(jnp.zeros(256))  # ERROR! Input size doesn't match num neurons
except Exception as e:
    print(f"Caught error as expected: {e}")
Caught error as expected: The shape of the original data is (512,), while we got (256,) with batch_axis=None.

✅ 解决方案:输入必须与 num 的大小相同:

[7]:
cann = CANN1D(num=512)
cann(jnp.zeros(512))  # Correct size

❌ 错误 2: 没有设置时间步

[8]:
from canns.models.basic import CANN1D
cann = CANN1D(num=512)  # Uses whatever dt was set before (or default)

✅ 解决方案: 在脚本开始时显式设置 dt

[ ]:
import brainpy.math as bm  # :cite:p:`wang2023brainpy`
bm.set_dt(0.1)  # Set dt first
cann = CANN1D(num=512)

关于 2D CANN?

相同的原理也适用于 2D 模型:

[9]:
from canns.models.basic import CANN2D

bm.set_dt(0.1)

# Create 2D CANN with 32x32 neurons
cann2d = CANN2D(32)

# Input must be (32, 32) shaped
stimulus_2d = jnp.zeros((32, 32))
cann2d(stimulus_2d)

print(f"2D activity shape: {cann2d.r.value.shape}")  # (32, 32)
2D activity shape: (32, 32)

API 几乎相同——只需调整您的输入维度!

后续步骤

现在您知道如何创建和运行 CANN 模型,您已准备好:


有疑问? 打开 GitHub Discussion.