如何训练脑启发模型?

目标: 阅读完本指南后,您将能够使用 Hebbian 学习训练一个 Hopfield 网络来存储和提取模式。

预计阅读时间: 12 分钟

引言

与使用反向传播训练的深度学习模型不同, 脑启发模型 使用生物学上合理的学习规则,例如 Hebbian 学习 [22]

本指南将介绍:

  1. Hebbian 学习原理

  2. 训练用于模式记忆的 Hopfield 网络

  3. 使用训练器框架

  4. 与深度学习的关键区别

我们将训练一个模型来记忆和提取图像——以此展示联想记忆的实际应用。

什么是 Hebbian 学习?

Hebbian 学习 [22] 是一种无监督学习规则,当突触前和突触后神经元共同激活时,突触权重会增强:

Δw_ij = η × x_i × x_j

其中:

  • w_ij: 从神经元 i 到神经元 j 的连接权重

  • x_i, x_j: 神经元 ij 的活动

  • η: 学习率(在简单规则中通常设为 1)

关键特性:

  • 局部性: 权重更新仅取决于相连神经元的活动(无需全局误差信号)

  • 无监督: 不需要标签或目标输出

  • 生物学合理性: 与真实神经元中观察到的突触可塑性相符

与反向传播的对比:

  • 反向传播:全局误差信号、有监督、需要可微的损失函数

  • Hebbian 学习:局部活动、无监督、无需梯度计算

Amari-Hopfield 网络

Amari-Hopfield 网络 [21, 23] 是一种将模式存储为稳定吸引子状态的循环网络。当输入一个部分或带噪声的模式时,网络会收敛到最接近的已存储记忆。

用例: 联想记忆、模式补全、错误纠正

[1]:
from canns.models.brain_inspired import AmariHopfieldNetwork  # :cite:p:`amari1977neural,hopfield1982neural`

# Create a Hopfield network with 16,384 neurons (128x128 flattened image)
model = AmariHopfieldNetwork(
    num_neurons=128 * 128,  # Image size when flattened
    asyn=False,             # Synchronous updates (all neurons update together)
    activation="sign"       # Binary activation: +1 or -1
)

参数:

  • num_neurons: 网络大小(必须与输入维度匹配)

  • asyn: 异步(一次一个神经元)更新与同步更新

  • activation: 二进制 Hopfield 网络用 “sign”,连续变体用 “tanh”

训练器框架

该库提供了一个 统一的训练器 API ,将训练逻辑抽象化:

  模型 + 训练器 + 数据 → 训练好的模型

**理念** (源自设计哲学文档):
  • 关注点分离: 模型定义动力学,训练器定义学习

  • 可复用性: 同一训练器适用于具有兼容学习规则的不同模型

  • 可组合性: 可堆叠训练器以进行多阶段训练

对于 Hebbian 学习,我们使用 HebbianTrainer:

[2]:
from canns.trainer import HebbianTrainer

trainer = HebbianTrainer(model)

关键方法:

  • trainer.train(data): 使用模式列表进行训练

  • trainer.predict(pattern): 提取单个模式

  • trainer.predict_batch(patterns): 批量提取(已编译以提升速度)

完整示例: 图像记忆

让我们训练一个 Amari-Hopfield 网络 [21] 来记忆 4 张图像,并从其损坏版本中将其提取出来。

第 1 步: 准备训练数据

[4]:
import numpy as np
import skimage.data
from skimage.color import rgb2gray
from skimage.transform import resize
from skimage.filters import threshold_mean

def preprocess_image(img, size=128):
    """Convert image to binary {-1, +1} pattern."""
    # Convert to grayscale if needed
    if img.ndim == 3:
        img = rgb2gray(img)

    # Resize to fixed size
    img = resize(img, (size, size), anti_aliasing=True)

    # Threshold to binary
    thresh = threshold_mean(img)
    binary = img > thresh

    # Map to {-1, +1}
    pattern = np.where(binary, 1.0, -1.0).astype(np.float32)

    # Flatten to 1D
    return pattern.reshape(size * size)

# Load example images from scikit-image
camera = preprocess_image(skimage.data.camera())
astronaut = preprocess_image(skimage.data.astronaut())
horse = preprocess_image(skimage.data.horse().astype(np.float32))
coffee = preprocess_image(skimage.data.coffee())

training_data = [camera, astronaut, horse, coffee]

print(f"Number of patterns: {len(training_data)}")
print(f"Pattern shape: {training_data[0].shape}")  # (16384,) = 128*128
Number of patterns: 4
Pattern shape: (16384,)

为什么使用二进制值 {-1, +1}?

  • 经典的 Amari-Hopfield 网络 [21] 使用二进制神经元

  • 这简化了能量函数和更新规则

  • 也存在实值扩展(使用 activation="tanh"

第 2 步: 创建模型和训练器

[5]:
from canns.models.brain_inspired import AmariHopfieldNetwork  # :cite:p:`amari1977neural,hopfield1982neural`
from canns.trainer import HebbianTrainer

# Create Hopfield network (auto-initializes)
model = AmariHopfieldNetwork(
    num_neurons=training_data[0].shape[0],
    asyn=False,
    activation="sign"
)

# Create Hebbian trainer
trainer = HebbianTrainer(model)

print("Model and trainer initialized!")
Model and trainer initialized!

第 3 步: 训练模型

[6]:
# Train on all patterns (this computes Hebbian weight matrix)
trainer.train(training_data)

print("Training complete! Patterns stored in weights.")
Training complete! Patterns stored in weights.

内部发生的事:

[ ]:
# Simplified version of Hebbian weight update
for pattern in training_data:
    W += np.outer(pattern, pattern)  # Hebbian: w_ij += x_i * x_j
W /= len(training_data)  # Normalize
np.fill_diagonal(W, 0)   # No self-connections

权重矩阵 W 现在已将所有训练模式编码为吸引子状态。

第 4 步:测试模式提取

创建训练图像的损坏版本:

[8]:
def corrupt_pattern(pattern, noise_level=0.3):
    """Randomly flip 30% of pixels."""
    corrupted = np.copy(pattern)
    num_flips = int(len(pattern) * noise_level)
    flip_indices = np.random.choice(len(pattern), num_flips, replace=False)
    corrupted[flip_indices] *= -1  # Flip sign
    return corrupted

# Create test patterns (30% corrupted)
test_patterns = [corrupt_pattern(p, 0.3) for p in training_data]

print(f"Created {len(test_patterns)} corrupted test patterns")
Created 4 corrupted test patterns

第 5 步: 提取模式

[9]:
# Batch prediction (compiled for efficiency)
recalled = trainer.predict_batch(test_patterns, show_sample_progress=True)

print("Pattern recall complete!")
print(f"Recalled patterns shape: {np.array(recalled).shape}")
Processing samples: 100%|█████████████| 4/4 [00:04<00:00,  1.15s/it, sample=4/4]
Pattern recall complete!
Recalled patterns shape: (4, 16384)

发生了什么:

  • 对于每个损坏的模式,网络会迭代其动力学过程

  • 网络活动会收敛到最接近的已存储吸引子(原始图像)

  • 结果就是被”清理干净”的模式

第 6 步: 可视化结果

[10]:
import matplotlib.pyplot as plt

def reshape_for_display(pattern, size=128):
    """Reshape 1D pattern back to 2D image."""
    return pattern.reshape(size, size)

# Plot original, corrupted, and recalled patterns
fig, axes = plt.subplots(len(training_data), 3, figsize=(8, 10))

for i in range(len(training_data)):
    # Column 1: Original training image
    axes[i, 0].imshow(reshape_for_display(training_data[i]), cmap='gray')
    axes[i, 0].axis('off')
    if i == 0:
        axes[i, 0].set_title('Original')

    # Column 2: Corrupted test input
    axes[i, 1].imshow(reshape_for_display(test_patterns[i]), cmap='gray')
    axes[i, 1].axis('off')
    if i == 0:
        axes[i, 1].set_title('Corrupted (30%)')

    # Column 3: Recalled output
    axes[i, 2].imshow(reshape_for_display(recalled[i]), cmap='gray')
    axes[i, 2].axis('off')
    if i == 0:
        axes[i, 2].set_title('Recalled')

plt.tight_layout()
plt.savefig('hopfield_memory_recall.png', dpi=150)
plt.show()

print("Visualization saved!")
../../_images/zh_1_quick_starts_05_train_brain_inspired_17_0.png
Visualization saved!

预期结果:

  • 原始图像是干净的

  • 损坏的输入含有约 30% 的噪声

  • 提取的输出与原始图像匹配(噪声已被纠正!)

完整可运行代码

以下是完整的单代码块示例:

[ ]:
import numpy as np
import skimage.data
from skimage.color import rgb2gray
from skimage.transform import resize
from skimage.filters import threshold_mean
from matplotlib import pyplot as plt

from canns.models.brain_inspired import AmariHopfieldNetwork  # :cite:p:`amari1977neural,hopfield1982neural`
from canns.trainer import HebbianTrainer

np.random.seed(42)

# 1. Preprocess images
def preprocess_image(img, size=128):
    if img.ndim == 3:
        img = rgb2gray(img)
    img = resize(img, (size, size), anti_aliasing=True)
    thresh = threshold_mean(img)
    binary = img > thresh
    pattern = np.where(binary, 1.0, -1.0).astype(np.float32)
    return pattern.reshape(size * size)

camera = preprocess_image(skimage.data.camera())
astronaut = preprocess_image(skimage.data.astronaut())
horse = preprocess_image(skimage.data.horse().astype(np.float32))
coffee = preprocess_image(skimage.data.coffee())

training_data = [camera, astronaut, horse, coffee]

# 2. Create model and trainer (auto-initializes)
model = AmariHopfieldNetwork(num_neurons=training_data[0].shape[0], asyn=False, activation="sign")
trainer = HebbianTrainer(model)

# 3. Train
trainer.train(training_data)

# 4. Create corrupted test patterns
def corrupt_pattern(pattern, noise_level=0.3):
    corrupted = np.copy(pattern)
    num_flips = int(len(pattern) * noise_level)
    flip_indices = np.random.choice(len(pattern), num_flips, replace=False)
    corrupted[flip_indices] *= -1
    return corrupted

test_patterns = [corrupt_pattern(p, 0.3) for p in training_data]

# 5. Recall patterns
recalled = trainer.predict_batch(test_patterns, show_sample_progress=True)

# 6. Visualize
def reshape(pattern, size=128):
    return pattern.reshape(size, size)

fig, axes = plt.subplots(len(training_data), 3, figsize=(8, 10))
for i in range(len(training_data)):
    axes[i, 0].imshow(reshape(training_data[i]), cmap='gray')
    axes[i, 0].axis('off')
    axes[i, 1].imshow(reshape(test_patterns[i]), cmap='gray')
    axes[i, 1].axis('off')
    axes[i, 2].imshow(reshape(recalled[i]), cmap='gray')
    axes[i, 2].axis('off')

axes[0, 0].set_title('Original')
axes[0, 1].set_title('Corrupted')
axes[0, 2].set_title('Recalled')

plt.tight_layout()
plt.savefig('hopfield_memory.png')
plt.show()

CANN 与深度学习训练对比

方面

Hebbian 学习

深度学习(人工神经网络)

学习规则

局部的(Hebbian 学习、STDP [24]

全局的(反向传播)

监督方式

无监督

有监督(通常)

梯度

无需梯度计算

需要自动微分

训练数据

待记忆的模式

带标签的输入-输出对

目标

形成吸引子

最小化损失函数

生物学合理性

速度

快(可能实现一次性学习)

慢(需要多个训练周期)

容量

有限(对于 N 个神经元,约可存储 0.15N 个模式)

非常大(过参数化)

何时使用:

  • Hebbian 学习/CANN: 联想记忆、模式补全、神经科学建模

  • 反向传播/人工神经网络: 分类、回归、大规模模式识别

训练器抽象

Trainer 框架为不同的学习范式提供了一致的接口:

[ ]:
# All trainers follow this pattern
trainer = SomeTrainer(model)
trainer.train(data)
output = trainer.predict(input)