如何训练脑启发模型?¶
目标: 阅读完本指南后,您将能够使用 Hebbian 学习训练一个 Hopfield 网络来存储和提取模式。
预计阅读时间: 12 分钟
引言¶
与使用反向传播训练的深度学习模型不同, 脑启发模型 使用生物学上合理的学习规则,例如 Hebbian 学习 [22]。
本指南将介绍:
Hebbian 学习原理
训练用于模式记忆的 Hopfield 网络
使用训练器框架
与深度学习的关键区别
我们将训练一个模型来记忆和提取图像——以此展示联想记忆的实际应用。
什么是 Hebbian 学习?¶
Hebbian 学习 [22] 是一种无监督学习规则,当突触前和突触后神经元共同激活时,突触权重会增强:
Δw_ij = η × x_i × x_j
其中:
w_ij: 从神经元i到神经元j的连接权重x_i, x_j: 神经元i和j的活动η: 学习率(在简单规则中通常设为 1)
关键特性:
局部性: 权重更新仅取决于相连神经元的活动(无需全局误差信号)
无监督: 不需要标签或目标输出
生物学合理性: 与真实神经元中观察到的突触可塑性相符
与反向传播的对比:
反向传播:全局误差信号、有监督、需要可微的损失函数
Hebbian 学习:局部活动、无监督、无需梯度计算
Amari-Hopfield 网络¶
Amari-Hopfield 网络 [21, 23] 是一种将模式存储为稳定吸引子状态的循环网络。当输入一个部分或带噪声的模式时,网络会收敛到最接近的已存储记忆。
用例: 联想记忆、模式补全、错误纠正
[1]:
from canns.models.brain_inspired import AmariHopfieldNetwork # :cite:p:`amari1977neural,hopfield1982neural`
# Create a Hopfield network with 16,384 neurons (128x128 flattened image)
model = AmariHopfieldNetwork(
num_neurons=128 * 128, # Image size when flattened
asyn=False, # Synchronous updates (all neurons update together)
activation="sign" # Binary activation: +1 or -1
)
参数:
num_neurons: 网络大小(必须与输入维度匹配)asyn: 异步(一次一个神经元)更新与同步更新activation: 二进制 Hopfield 网络用 “sign”,连续变体用 “tanh”
训练器框架¶
该库提供了一个 统一的训练器 API ,将训练逻辑抽象化:
模型 + 训练器 + 数据 → 训练好的模型
**理念** (源自设计哲学文档):
关注点分离: 模型定义动力学,训练器定义学习
可复用性: 同一训练器适用于具有兼容学习规则的不同模型
可组合性: 可堆叠训练器以进行多阶段训练
对于 Hebbian 学习,我们使用 HebbianTrainer:
[2]:
from canns.trainer import HebbianTrainer
trainer = HebbianTrainer(model)
关键方法:
trainer.train(data): 使用模式列表进行训练trainer.predict(pattern): 提取单个模式trainer.predict_batch(patterns): 批量提取(已编译以提升速度)
完整示例: 图像记忆¶
让我们训练一个 Amari-Hopfield 网络 [21] 来记忆 4 张图像,并从其损坏版本中将其提取出来。
第 1 步: 准备训练数据¶
[4]:
import numpy as np
import skimage.data
from skimage.color import rgb2gray
from skimage.transform import resize
from skimage.filters import threshold_mean
def preprocess_image(img, size=128):
"""Convert image to binary {-1, +1} pattern."""
# Convert to grayscale if needed
if img.ndim == 3:
img = rgb2gray(img)
# Resize to fixed size
img = resize(img, (size, size), anti_aliasing=True)
# Threshold to binary
thresh = threshold_mean(img)
binary = img > thresh
# Map to {-1, +1}
pattern = np.where(binary, 1.0, -1.0).astype(np.float32)
# Flatten to 1D
return pattern.reshape(size * size)
# Load example images from scikit-image
camera = preprocess_image(skimage.data.camera())
astronaut = preprocess_image(skimage.data.astronaut())
horse = preprocess_image(skimage.data.horse().astype(np.float32))
coffee = preprocess_image(skimage.data.coffee())
training_data = [camera, astronaut, horse, coffee]
print(f"Number of patterns: {len(training_data)}")
print(f"Pattern shape: {training_data[0].shape}") # (16384,) = 128*128
Number of patterns: 4
Pattern shape: (16384,)
为什么使用二进制值 {-1, +1}?
经典的 Amari-Hopfield 网络 [21] 使用二进制神经元
这简化了能量函数和更新规则
也存在实值扩展(使用
activation="tanh")
第 2 步: 创建模型和训练器¶
[5]:
from canns.models.brain_inspired import AmariHopfieldNetwork # :cite:p:`amari1977neural,hopfield1982neural`
from canns.trainer import HebbianTrainer
# Create Hopfield network (auto-initializes)
model = AmariHopfieldNetwork(
num_neurons=training_data[0].shape[0],
asyn=False,
activation="sign"
)
# Create Hebbian trainer
trainer = HebbianTrainer(model)
print("Model and trainer initialized!")
Model and trainer initialized!
第 3 步: 训练模型¶
[6]:
# Train on all patterns (this computes Hebbian weight matrix)
trainer.train(training_data)
print("Training complete! Patterns stored in weights.")
Training complete! Patterns stored in weights.
内部发生的事:
[ ]:
# Simplified version of Hebbian weight update
for pattern in training_data:
W += np.outer(pattern, pattern) # Hebbian: w_ij += x_i * x_j
W /= len(training_data) # Normalize
np.fill_diagonal(W, 0) # No self-connections
权重矩阵 W 现在已将所有训练模式编码为吸引子状态。
第 4 步:测试模式提取¶
创建训练图像的损坏版本:
[8]:
def corrupt_pattern(pattern, noise_level=0.3):
"""Randomly flip 30% of pixels."""
corrupted = np.copy(pattern)
num_flips = int(len(pattern) * noise_level)
flip_indices = np.random.choice(len(pattern), num_flips, replace=False)
corrupted[flip_indices] *= -1 # Flip sign
return corrupted
# Create test patterns (30% corrupted)
test_patterns = [corrupt_pattern(p, 0.3) for p in training_data]
print(f"Created {len(test_patterns)} corrupted test patterns")
Created 4 corrupted test patterns
第 5 步: 提取模式¶
[9]:
# Batch prediction (compiled for efficiency)
recalled = trainer.predict_batch(test_patterns, show_sample_progress=True)
print("Pattern recall complete!")
print(f"Recalled patterns shape: {np.array(recalled).shape}")
Processing samples: 100%|█████████████| 4/4 [00:04<00:00, 1.15s/it, sample=4/4]
Pattern recall complete!
Recalled patterns shape: (4, 16384)
发生了什么:
对于每个损坏的模式,网络会迭代其动力学过程
网络活动会收敛到最接近的已存储吸引子(原始图像)
结果就是被”清理干净”的模式
第 6 步: 可视化结果¶
[10]:
import matplotlib.pyplot as plt
def reshape_for_display(pattern, size=128):
"""Reshape 1D pattern back to 2D image."""
return pattern.reshape(size, size)
# Plot original, corrupted, and recalled patterns
fig, axes = plt.subplots(len(training_data), 3, figsize=(8, 10))
for i in range(len(training_data)):
# Column 1: Original training image
axes[i, 0].imshow(reshape_for_display(training_data[i]), cmap='gray')
axes[i, 0].axis('off')
if i == 0:
axes[i, 0].set_title('Original')
# Column 2: Corrupted test input
axes[i, 1].imshow(reshape_for_display(test_patterns[i]), cmap='gray')
axes[i, 1].axis('off')
if i == 0:
axes[i, 1].set_title('Corrupted (30%)')
# Column 3: Recalled output
axes[i, 2].imshow(reshape_for_display(recalled[i]), cmap='gray')
axes[i, 2].axis('off')
if i == 0:
axes[i, 2].set_title('Recalled')
plt.tight_layout()
plt.savefig('hopfield_memory_recall.png', dpi=150)
plt.show()
print("Visualization saved!")
Visualization saved!
预期结果:
原始图像是干净的
损坏的输入含有约 30% 的噪声
提取的输出与原始图像匹配(噪声已被纠正!)
完整可运行代码¶
以下是完整的单代码块示例:
[ ]:
import numpy as np
import skimage.data
from skimage.color import rgb2gray
from skimage.transform import resize
from skimage.filters import threshold_mean
from matplotlib import pyplot as plt
from canns.models.brain_inspired import AmariHopfieldNetwork # :cite:p:`amari1977neural,hopfield1982neural`
from canns.trainer import HebbianTrainer
np.random.seed(42)
# 1. Preprocess images
def preprocess_image(img, size=128):
if img.ndim == 3:
img = rgb2gray(img)
img = resize(img, (size, size), anti_aliasing=True)
thresh = threshold_mean(img)
binary = img > thresh
pattern = np.where(binary, 1.0, -1.0).astype(np.float32)
return pattern.reshape(size * size)
camera = preprocess_image(skimage.data.camera())
astronaut = preprocess_image(skimage.data.astronaut())
horse = preprocess_image(skimage.data.horse().astype(np.float32))
coffee = preprocess_image(skimage.data.coffee())
training_data = [camera, astronaut, horse, coffee]
# 2. Create model and trainer (auto-initializes)
model = AmariHopfieldNetwork(num_neurons=training_data[0].shape[0], asyn=False, activation="sign")
trainer = HebbianTrainer(model)
# 3. Train
trainer.train(training_data)
# 4. Create corrupted test patterns
def corrupt_pattern(pattern, noise_level=0.3):
corrupted = np.copy(pattern)
num_flips = int(len(pattern) * noise_level)
flip_indices = np.random.choice(len(pattern), num_flips, replace=False)
corrupted[flip_indices] *= -1
return corrupted
test_patterns = [corrupt_pattern(p, 0.3) for p in training_data]
# 5. Recall patterns
recalled = trainer.predict_batch(test_patterns, show_sample_progress=True)
# 6. Visualize
def reshape(pattern, size=128):
return pattern.reshape(size, size)
fig, axes = plt.subplots(len(training_data), 3, figsize=(8, 10))
for i in range(len(training_data)):
axes[i, 0].imshow(reshape(training_data[i]), cmap='gray')
axes[i, 0].axis('off')
axes[i, 1].imshow(reshape(test_patterns[i]), cmap='gray')
axes[i, 1].axis('off')
axes[i, 2].imshow(reshape(recalled[i]), cmap='gray')
axes[i, 2].axis('off')
axes[0, 0].set_title('Original')
axes[0, 1].set_title('Corrupted')
axes[0, 2].set_title('Recalled')
plt.tight_layout()
plt.savefig('hopfield_memory.png')
plt.show()
CANN 与深度学习训练对比¶
方面 |
Hebbian 学习 |
深度学习(人工神经网络) |
|---|---|---|
学习规则 |
局部的(Hebbian 学习、STDP [24]) |
全局的(反向传播) |
监督方式 |
无监督 |
有监督(通常) |
梯度 |
无需梯度计算 |
需要自动微分 |
训练数据 |
待记忆的模式 |
带标签的输入-输出对 |
目标 |
形成吸引子 |
最小化损失函数 |
生物学合理性 |
高 |
低 |
速度 |
快(可能实现一次性学习) |
慢(需要多个训练周期) |
容量 |
有限(对于 N 个神经元,约可存储 0.15N 个模式) |
非常大(过参数化) |
何时使用:
Hebbian 学习/CANN: 联想记忆、模式补全、神经科学建模
反向传播/人工神经网络: 分类、回归、大规模模式识别
训练器抽象¶
Trainer 框架为不同的学习范式提供了一致的接口:
[ ]:
# All trainers follow this pattern
trainer = SomeTrainer(model)
trainer.train(data)
output = trainer.predict(input)