RNN Fixed-Point Analysis Tutorial: A FlipFlop Task Example

Goal: This document is a comprehensive tutorial on using the FixedPointFinder tool to analyze an RNN trained on the FlipFlop task.

Structure:

  1. Theoretical background: what is a fixed point?

  2. Setup: import the libraries used by the script

  3. Component definitions: walk through the FlipFlopData class, the FlipFlopRNN class, and the train_flipflop_rnn function

  4. Core usage: demonstrate the FixedPointFinder workflow

1. Theoretical Background: What Is a Fixed Point?

A fixed point [29, 30] is a core concept in dynamical systems. An RNN can be viewed as a function h_{t+1} = F(h_t, u_t), where h is the hidden state and u is the input.

When the input u is held constant (for example, during the "memory" phase of the FlipFlop task with no input pulses, u = 0), the system evolves as h_{t+1} = F(h_t).

A fixed point x* is a state satisfying x* = F(x*).

  • Stable fixed point: acts as an "attractor". If the RNN state h is close to x*, it eventually converges to x*.

  • Unstable fixed point: acts as a "repeller" or a "saddle point".
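
As a toy illustration (not part of the script), consider the scalar map h_{t+1} = tanh(w * h_t) with w = 2: it has an unstable fixed point at 0 and stable fixed points near ±0.9575, so iterating from any nonzero state converges to one of the stable points:

import numpy as np

w = 2.0
h = 0.1  # a small nonzero initial state near the unstable fixed point at 0
for _ in range(50):
    h = np.tanh(w * h)  # iterate the map h_{t+1} = tanh(w * h_t)
print(h)  # ~0.9575, the positive stable fixed point solving x = tanh(2x)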

Core principle: train the RNN on the FlipFlop task. Once trained, the RNN builds a stable fixed point for each state it needs to "remember" (such as [+1, +1] or [+1, -1]). When the input is u = 0, the RNN state flows toward and settles at one of these fixed points, which is what implements the "memory".

The purpose of this tutorial is to use the FixedPointFinder tool to discover all the stable fixed points "hidden" inside the RNN.

2. Imports

Import all the libraries used by the flipflop_fixed_points.py script.

[1]:
import brainpy as bp  # :cite:p:`wang2023brainpy`
import brainpy.math as bm
import jax  # :cite:p:`jax2018github`
import jax.numpy as jnp
import numpy as np
import random
from canns.analyzer.visualization import PlotConfig
from canns.analyzer.slow_points import FixedPointFinder, save_checkpoint, load_checkpoint, plot_fixed_points_2d, plot_fixed_points_3d

3. Component Definitions: Data, Model, and Training

This section defines the three core components from the flipflop_fixed_points.py script.

3.1 Component 1: The FlipFlopData Class

This is the FlipFlopData class from the flipflop_fixed_points.py script. Each trial consists of random signed pulses on n_bits channels; the target on each channel holds the value of the most recent pulse.

[2]:
class FlipFlopData:
    """Generator for flip-flop memory task data."""

    def __init__(self, n_bits=3, n_time=64, p=0.5, random_seed=0):
        """Initialize FlipFlopData generator.

        Args:
            n_bits: Number of memory channels.
            n_time: Number of timesteps per trial.
            p: Probability of input pulse at each timestep.
            random_seed: Random seed for reproducibility.
        """
        self.rng = np.random.RandomState(random_seed)
        self.n_time = n_time
        self.n_bits = n_bits
        self.p = p

    def generate_data(self, n_trials):
        """Generate flip-flop task data.

        Args:
            n_trials: Number of trials to generate.

        Returns:
            dict with 'inputs' and 'targets' arrays [n_trials x n_time x n_bits].
        """
        n_time = self.n_time
        n_bits = self.n_bits
        p = self.p

        # Generate unsigned input pulses
        unsigned_inputs = self.rng.binomial(1, p, [n_trials, n_time, n_bits])

        # Ensure every trial starts with a pulse
        unsigned_inputs[:, 0, :] = 1

        # Generate random signs {-1, +1}
        random_signs = 2 * self.rng.binomial(1, 0.5, [n_trials, n_time, n_bits]) - 1

        # Apply random signs
        inputs = unsigned_inputs * random_signs

        # Compute targets
        targets = np.zeros([n_trials, n_time, n_bits])
        for trial_idx in range(n_trials):
            for bit_idx in range(n_bits):
                input_seq = inputs[trial_idx, :, bit_idx]
                t_flip = np.where(input_seq != 0)[0]
                for flip_idx in range(len(t_flip)):
                    t_flip_i = t_flip[flip_idx]
                    targets[trial_idx, t_flip_i:, bit_idx] = inputs[
                        trial_idx, t_flip_i, bit_idx
                    ]

        return {
            "inputs": inputs.astype(np.float32),
            "targets": targets.astype(np.float32),
        }
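
As a quick sanity check (a hypothetical snippet, not part of the original script), we can generate one short trial and confirm that each target channel holds the sign of its most recent pulse:

demo = FlipFlopData(n_bits=2, n_time=8, random_seed=1)
batch = demo.generate_data(n_trials=1)
print(batch["inputs"][0, :, 0])   # signed pulses in {-1, 0, +1} on channel 0
print(batch["targets"][0, :, 0])  # the held value of the most recent pulse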

3.2 Component 2: The FlipFlopRNN Class

This is the FlipFlopRNN class from the flipflop_fixed_points.py script.

Usage note: FixedPointFinder works by searching for states with x = F(x, u). To evaluate F(x, u), it calls rnn_model(inputs, hidden), passing inputs of shape [batch, 1, n_inputs] and hidden of shape [batch, n_hidden].

Your __call__ method must therefore handle the n_time == 1 case and return (outputs, h_next).

See the if n_time == 1: branch in the code below, which exists specifically to support FixedPointFinder.

[3]:
class FlipFlopRNN(bp.DynamicalSystem):
    """RNN model for the flip-flop memory task."""

    def __init__(self, n_inputs, n_hidden, n_outputs, rnn_type="gru", seed=0):
        """Initialize FlipFlop RNN.

        Args:
            n_inputs: Number of input channels.
            n_hidden: Number of hidden units.
            n_outputs: Number of output channels.
            rnn_type: Type of RNN cell ('tanh', 'gru').
            seed: Random seed for weight initialization.
        """
        super().__init__()
        self.n_inputs = n_inputs
        self.n_hidden = n_hidden
        self.n_outputs = n_outputs
        self.rnn_type = rnn_type.lower()

        # Initialize RNN cell parameters
        key = jax.random.PRNGKey(seed)
        k1, k2, k3, k4 = jax.random.split(key, 4)

        if self.rnn_type == "tanh":
            # Simple tanh RNN
            self.w_ih = bm.Variable(
                jax.random.normal(k1, (n_inputs, n_hidden)) * 0.1
            )
            self.w_hh = bm.Variable(
                jax.random.normal(k2, (n_hidden, n_hidden)) * 0.5
            )
            self.b_h = bm.Variable(jnp.zeros(n_hidden))
        elif self.rnn_type == "gru":
            # GRU cell
            self.w_ir = bm.Variable(
                jax.random.normal(k1, (n_inputs, n_hidden)) * 0.1
            )
            self.w_hr = bm.Variable(
                jax.random.normal(k2, (n_hidden, n_hidden)) * 0.5
            )
            self.w_iz = bm.Variable(
                jax.random.normal(k3, (n_inputs, n_hidden)) * 0.1
            )
            self.w_hz = bm.Variable(
                jax.random.normal(k4, (n_hidden, n_hidden)) * 0.5
            )
            k5, k6, k7, k8 = jax.random.split(k4, 4)
            self.w_in = bm.Variable(
                jax.random.normal(k5, (n_inputs, n_hidden)) * 0.1
            )
            self.w_hn = bm.Variable(
                jax.random.normal(k6, (n_hidden, n_hidden)) * 0.5
            )
            self.b_r = bm.Variable(jnp.zeros(n_hidden))
            self.b_z = bm.Variable(jnp.zeros(n_hidden))
            self.b_n = bm.Variable(jnp.zeros(n_hidden))
        else:
            raise ValueError(f"Unsupported rnn_type: {rnn_type}")

        # Readout layer
        self.w_out = bm.Variable(
            jax.random.normal(k3, (n_hidden, n_outputs)) * 0.1
        )
        self.b_out = bm.Variable(jnp.zeros(n_outputs))

        # Initial hidden state
        self.h0 = bm.Variable(jnp.zeros(n_hidden))

    def step(self, x_t, h):
        """Single RNN step.

        Args:
            x_t: [batch_size x n_inputs] input at time t.
            h: [batch_size x n_hidden] hidden state.

        Returns:
            h_next: [batch_size x n_hidden] next hidden state.
        """
        if self.rnn_type == "tanh":
            # Simple tanh RNN step
            h_next = jnp.tanh(
                x_t @ self.w_ih.value + h @ self.w_hh.value + self.b_h.value
            )
        elif self.rnn_type == "gru":
            # GRU step
            r = jax.nn.sigmoid(
                x_t @ self.w_ir.value + h @ self.w_hr.value + self.b_r.value
            )
            z = jax.nn.sigmoid(
                x_t @ self.w_iz.value + h @ self.w_hz.value + self.b_z.value
            )
            n = jnp.tanh(
                x_t @ self.w_in.value + (r * h) @ self.w_hn.value + self.b_n.value
            )
            h_next = (1 - z) * n + z * h
        else:
            raise ValueError(f"Unknown rnn_type: {self.rnn_type}")

        return h_next

    def __call__(self, inputs, hidden=None):
        """Forward pass through the RNN. Optimized with jax.lax.scan."""
        batch_size = inputs.shape[0]
        n_time = inputs.shape[1]

        # Initialize hidden state
        if hidden is None:
            h = jnp.tile(self.h0.value, (batch_size, 1))
        else:
            h = hidden

        # Single-step computation mode for the fixed-point finder
        if n_time == 1:
            x_t = inputs[:, 0, :]
            h_next = self.step(x_t, h)
            y = h_next @ self.w_out.value + self.b_out.value
            return y[:, None, :], h_next

        # Full sequence case
        def scan_fn(carry, x_t):
            """Single-step scan function"""
            h_prev = carry
            h_next = self.step(x_t, h_prev)
            y_t = h_next @ self.w_out.value + self.b_out.value
            return h_next, (y_t, h_next)

        # (batch, time, features) -> (time, batch, features)
        inputs_transposed = inputs.transpose(1, 0, 2)

        # Run the scan
        _, (outputs_seq, hiddens_seq) = jax.lax.scan(scan_fn, h, inputs_transposed)

        outputs = outputs_seq.transpose(1, 0, 2)
        hiddens = hiddens_seq.transpose(1, 0, 2)

        return outputs, hiddens
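
To make the single-step contract concrete, here is a hypothetical shape check (not part of the script) using the class defined above:

demo_rnn = FlipFlopRNN(n_inputs=3, n_hidden=4, n_outputs=3, rnn_type="tanh", seed=0)
x = jnp.zeros((5, 1, 3))      # inputs: [batch, 1, n_inputs], as FixedPointFinder passes them
h = jnp.zeros((5, 4))         # hidden: [batch, n_hidden]
y, h_next = demo_rnn(x, h)
print(y.shape, h_next.shape)  # (5, 1, 3) (5, 4)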

3.3 Component 3: The train_flipflop_rnn Function

This is the train_flipflop_rnn function from the flipflop_fixed_points.py script.

Key updates:

  • Uses the modern BrainPy optimizer API: bp.optimizers.Adam(lr=..., train_vars=...)

  • Handles parameter-name mapping (vars() returns fully qualified names such as 'FlipFlopRNN0.w_ih'; see the short illustration after this list)

  • No longer uses braintools or the deprecated register_trainable_weights()
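
For example, a hypothetical illustration of this mapping (the names in the comments are representative, not captured output):

tmp = FlipFlopRNN(n_inputs=2, n_hidden=3, n_outputs=2, rnn_type="tanh", seed=0)
full_names = list(tmp.vars().keys())          # e.g. ['FlipFlopRNN0.w_ih', 'FlipFlopRNN0.w_hh', ...]
short_to_full = {n.split('.')[-1]: n for n in full_names}
print(short_to_full['w_out'])                 # maps the short name back to the fully qualified name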

[4]:
def train_flipflop_rnn(rnn, train_data, valid_data,
                       learning_rate=0.08,
                       batch_size=128,
                       max_epochs=1000,
                       min_loss=1e-4,
                       print_every=10):
    print("\n" + "=" * 70)
    print("Training FlipFlop RNN (Using brainpy optimizer)")
    print("=" * 70)

    # Prepare data
    train_inputs = jnp.array(train_data['inputs'])
    train_targets = jnp.array(train_data['targets'])
    valid_inputs = jnp.array(valid_data['inputs'])
    valid_targets = jnp.array(valid_data['targets'])
    n_train = train_inputs.shape[0]
    n_batches = n_train // batch_size

    # Get trainable variables from the model
    # Note: vars() returns keys like 'FlipFlopRNN0.w_ih', we need just 'w_ih' for computation
    train_vars = {name: var for name, var in rnn.vars().items() if isinstance(var, bm.Variable)}
    # Create mapping between short names and full names
    name_mapping = {name.split('.')[-1]: name for name in train_vars.keys()}
    # Extract just the parameter name (after the last dot) for gradient computation
    params = {name.split('.')[-1]: var.value for name, var in train_vars.items()}

    # Initialize optimizer with train_vars parameter (modern brainpy API)
    optimizer = bp.optimizers.Adam(lr=learning_rate, train_vars=train_vars)

    # Define JIT-compiled gradient step
    @jax.jit
    def grad_step(params, batch_inputs, batch_targets):
        """Pure function to compute loss and gradients"""
        def forward_pass(p, inputs):
            batch_size = inputs.shape[0]
            h = jnp.tile(p['h0'], (batch_size, 1))

            def scan_fn(carry, x_t):
                h_prev = carry
                if rnn.rnn_type == "tanh":
                    h_next = jnp.tanh(x_t @ p['w_ih'] + h_prev @ p['w_hh'] + p['b_h'])
                elif rnn.rnn_type == "gru":
                    r = jax.nn.sigmoid(x_t @ p['w_ir'] + h_prev @ p['w_hr'] + p['b_r'])
                    z = jax.nn.sigmoid(x_t @ p['w_iz'] + h_prev @ p['w_hz'] + p['b_z'])
                    n = jnp.tanh(x_t @ p['w_in'] + (r * h_prev) @ p['w_hn'] + p['b_n'])
                    h_next = (1 - z) * n + z * h_prev
                else:
                    h_next = h_prev
                y_t = h_next @ p['w_out'] + p['b_out']
                return h_next, y_t

            inputs_transposed = inputs.transpose(1, 0, 2)
            _, outputs_seq = jax.lax.scan(scan_fn, h, inputs_transposed)
            outputs = outputs_seq.transpose(1, 0, 2)
            return outputs

        def loss_fn(p):
            outputs = forward_pass(p, batch_inputs)
            return jnp.mean((outputs - batch_targets) ** 2)

        loss_val, grads = jax.value_and_grad(loss_fn)(params)
        return loss_val, grads

    losses = []
    print("\nTraining parameters:")
    print(f"  Batch size: {batch_size}")
    print(f"  Learning rate:{learning_rate:.6f} (Fixed)")

    for epoch in range(max_epochs):
        perm = np.random.permutation(n_train)
        epoch_loss = 0.0
        for batch_idx in range(n_batches):
            start_idx = batch_idx * batch_size
            end_idx = start_idx + batch_size
            batch_inputs = train_inputs[perm[start_idx:end_idx]]
            batch_targets = train_targets[perm[start_idx:end_idx]]
            loss_val, grads_short = grad_step(params, batch_inputs, batch_targets)
            # Map gradients back to full names for optimizer
            grads = {name_mapping[short_name]: grad for short_name, grad in grads_short.items()}
            optimizer.update(grads)
            # Update params with current variable values (extract parameter names)
            params = {name.split('.')[-1]: var.value for name, var in train_vars.items()}
            epoch_loss += float(loss_val)
        epoch_loss /= n_batches
        losses.append(epoch_loss)

        if epoch % print_every == 0:
            valid_outputs, _ = rnn(valid_inputs)
            valid_loss = float(jnp.mean((valid_outputs - valid_targets) ** 2))
            print(f"Epoch {epoch:4d}: train_loss = {epoch_loss:.6f}, "
                  f"valid_loss = {valid_loss:.6f}")
        if epoch_loss < min_loss:
            print(f"\nReached target loss {min_loss:.2e} at epoch {epoch}")
            break

    # Training complete
    valid_outputs, _ = rnn(valid_inputs)
    final_valid_loss = float(jnp.mean((valid_outputs - valid_targets) ** 2))
    print("\n" + "=" * 70)
    print("Training Complete!")
    print("=" * 70)
    print(f"Final training loss: {epoch_loss:.6f}")
    print(f"Final validation loss: {final_valid_loss:.6f}")
    print(f"Total epochs: {epoch + 1}")
    return losses

4. Core Usage: Training and Finding Fixed Points

We will reproduce the logic of the main function and the if __name__ == "__main__": block from the flipflop_fixed_points.py script.

We will:

  1. Define the task configuration

  2. Set parameters and generate data

  3. Train the model, or load it if a checkpoint exists

  4. Initialize and run FixedPointFinder

  5. Print the results and visualize them

4.1 Step 1: Define the Configuration and Parameters

This part comes from the global TASK_CONFIGS dictionary, the if __name__ == "__main__": block, and the beginning of the main function in flipflop_fixed_points.py.

[5]:
# Configuration Dictionary
TASK_CONFIGS = {
    "2_bit": {
        "n_bits": 2,
        "n_hidden": 3,
        "n_trials_train": 512,
        "n_inits":1024,
    },
    "3_bit": {
        "n_bits": 3,
        "n_hidden": 4,
        "n_trials_train": 512,
        "n_inits":1024,
    },
    "4_bit": {
        "n_bits": 4,
        "n_hidden": 6,
        "n_trials_train": 512,
        "n_inits":1024,
    },
}

# --- Set parameters ---
# (This part is from the if __name__ == "__main__" block in the original script)
config_to_run = "3_bit"  # Specify which configuration to run
seed_to_use = 42       # Use a fixed seed

config_name = config_to_run
seed = seed_to_use

# (This part is from the main function in the original script)
if config_name not in TASK_CONFIGS:
    raise ValueError(f"Unknown config_name: {config_name}. Available: {list(TASK_CONFIGS.keys())}")
config = TASK_CONFIGS[config_name]

# Set random seeds
np.random.seed(seed)
random.seed(seed)

print(f"\n--- Running FlipFlop Task ({config_name}) ---")
print(f"Seed: {seed}")

n_bits = config["n_bits"]
n_hidden = config["n_hidden"]
n_trials_train = config["n_trials_train"]
n_inits = config["n_inits"]

n_time = 64
n_trials_valid = 128
n_trials_test = 128
rnn_type = "tanh"
learning_rate = 0.08
batch_size = 128
max_epochs = 500  # originally 1000; 500 runs faster in the notebook
min_loss = 1e-4

--- Running FlipFlop Task (3_bit) ---
Seed: 42

4.2 Step 2: Generate Data and Train the Model

This part comes from the main function in flipflop_fixed_points.py.

[6]:
# Generate data
data_gen = FlipFlopData(n_bits=n_bits, n_time=n_time, p=0.5, random_seed=seed)
train_data = data_gen.generate_data(n_trials_train)
valid_data = data_gen.generate_data(n_trials_valid)
test_data = data_gen.generate_data(n_trials_test)

# Create RNN model
rnn = FlipFlopRNN(n_inputs=n_bits, n_hidden=n_hidden, n_outputs=n_bits, rnn_type=rnn_type, seed=seed)

# Check for checkpoint
checkpoint_path = f"flipflop_rnn_{config_name}_checkpoint.msgpack"
if load_checkpoint(rnn, checkpoint_path):
    print(f"Loaded model from checkpoint: {checkpoint_path}")
else:
    # Train the RNN
    print(f"No checkpoint found ({checkpoint_path}). Training...")
    losses = train_flipflop_rnn(
        rnn,
        train_data,
        valid_data,
        learning_rate=learning_rate,
        batch_size=batch_size,
        max_epochs=max_epochs,
        min_loss=min_loss,
        print_every=10
    )
No checkpoint found (flipflop_rnn_3_bit_checkpoint.msgpack). Training...

======================================================================
Training FlipFlop RNN (Using brainpy optimizer)
======================================================================

Training parameters:
  Batch size: 128
  Learning rate: 0.080000 (Fixed)
Epoch    0: train_loss = 0.934704, valid_loss = 0.716311
Epoch   10: train_loss = 0.006317, valid_loss = 0.006368
Epoch   20: train_loss = 0.000650, valid_loss = 0.000600
Epoch   30: train_loss = 0.000387, valid_loss = 0.000375
Epoch   40: train_loss = 0.000302, valid_loss = 0.000295
Epoch   50: train_loss = 0.000257, valid_loss = 0.000253
Epoch   60: train_loss = 0.000229, valid_loss = 0.000227
Epoch   70: train_loss = 0.000209, valid_loss = 0.000207
Epoch   80: train_loss = 0.000193, valid_loss = 0.000191
Epoch   90: train_loss = 0.000179, valid_loss = 0.000178
Epoch  100: train_loss = 0.000167, valid_loss = 0.000166
Epoch  110: train_loss = 0.000157, valid_loss = 0.000156
Epoch  120: train_loss = 0.000147, valid_loss = 0.000147
Epoch  130: train_loss = 0.000139, valid_loss = 0.000138
Epoch  140: train_loss = 0.000131, valid_loss = 0.000131
Epoch  150: train_loss = 0.000124, valid_loss = 0.000124
Epoch  160: train_loss = 0.000118, valid_loss = 0.000117
Epoch  170: train_loss = 0.000112, valid_loss = 0.000111
Epoch  180: train_loss = 0.000106, valid_loss = 0.000106
Epoch  190: train_loss = 0.000101, valid_loss = 0.000101

Reached target loss 1.00e-04 at epoch 193

======================================================================
Training Complete!
======================================================================
Final training loss: 0.000100
Final validation loss: 0.000100
Total epochs: 194

4.3 Step 3: Run the Fixed-Point Analysis

This is the concrete usage of FixedPointFinder, taken from the main function.

Usage notes:

  1. Collect state trajectories: hiddens_np. FixedPointFinder samples its initial points from these "real" states.

  2. Initialize FixedPointFinder:

    • rnn_model: pass the rnn instance

    • do_compute_jacobians=True: must be True so that the Jacobian J = dF/dx is computed

    • do_decompose_jacobians=True: must be True so that the eigenvalues of J are computed to determine stability

  3. Run find_fixed_points:

    • state_traj: pass hiddens_np

    • inputs: we want the "memory" states, i.e., fixed points in the absence of input, so we pass the constant zero vector constant_input (see the q-value sketch after this list)
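
In the Golub & Sussillo FixedPointFinder formulation, the quantity being minimized is q(x) = 0.5 * ||F(x, u) - x||^2. Assuming this library follows the same convention (an assumption, not confirmed here), a minimal sketch of evaluating q at a batch of candidate states via the single-step mode:

def q_value(model, states, const_input):
    """Hypothetical q(x) = 0.5 * ||F(x, u) - x||^2, evaluated per state."""
    batch = states.shape[0]
    u = jnp.tile(const_input, (batch, 1))[:, None, :]  # [batch, 1, n_inputs]
    _, next_states = model(u, states)                  # single-step mode
    return 0.5 * jnp.sum((next_states - states) ** 2, axis=1)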

[7]:
# Fixed Point Analysis
print("\n--- Fixed Point Analysis ---")
inputs_jax = jnp.array(test_data["inputs"])
outputs, hiddens = rnn(inputs_jax)
hiddens_np = np.array(hiddens)

# Find fixed points
finder = FixedPointFinder(
    rnn,
    method="joint",
    max_iters=5000,
    lr_init=0.02,
    tol_q=1e-4,
    final_q_threshold=1e-6,
    tol_unique=1e-2,
    do_compute_jacobians=True,
    do_decompose_jacobians=True,
    outlier_distance_scale=10.0,
    verbose=True,
    super_verbose=True,
)

constant_input = np.zeros((1, n_bits), dtype=np.float32)

unique_fps, all_fps = finder.find_fixed_points(
    state_traj=hiddens_np,
    inputs=constant_input,
    n_inits=n_inits,
    noise_scale=0.4,
)

--- Fixed Point Analysis ---

Searching for fixed points from 1024 initial states.

        Finding fixed points via joint optimization.
/var/folders/x0/_jqxxbbn0rsdn6b4h6fxbrjr0000gn/T/ipykernel_3897/1298414900.py:25: UserWarning: Joint optimization with n_inits=1024 may be inefficient and use excessive memory. Consider using sequential optimization or reducing n_inits.
  unique_fps, all_fps = finder.find_fixed_points(
        Iter: 100, q = 2.76e-04 +/- 1.94e-03, dq = 8.52e-06 +/- 4.71e-05, lr = 2.00e-02, avg iter time = 1.18e-02 sec.
        Iter: 200, q = 4.19e-05 +/- 9.01e-04, dq = 3.81e-07 +/- 6.37e-06, lr = 2.00e-02, avg iter time = 8.01e-03 sec.
        Iter: 300, q = 4.77e-06 +/- 1.16e-04, dq = 1.86e-07 +/- 4.34e-06, lr = 2.00e-02, avg iter time = 6.71e-03 sec.
        Optimization complete to desired tolerance.
                384 iters, q = 1.06e-07 +/- 3.12e-06, dq = 5.27e-09 +/- 1.49e-07, lr = 2.00e-02, avg iter time = 6.15e-03 sec
        Identified 26 unique fixed points.
        Computing recurrent Jacobian at 26 unique fixed points.
        Computing input Jacobian at 26 unique fixed points.
Decomposing 26 Jacobians...
Found 8 stable and 18 unstable fixed points.
        Applying final q-value filter (q < 1.0e-06)...
                Excluded 1 low-quality fixed points.
                25 high-quality fixed points remain.
        Fixed point finding complete.

4.4 Interpreting and Visualizing the Results

find_fixed_points returns two objects:

  • all_fps: every result found from the n_inits initial points

  • unique_fps: the result we care about most; the set of distinct fixed points remaining after deduplication with tol_unique

How to read the results:

  • unique_fps.n: the number of unique fixed points found

  • unique_fps.qstar: the q values; the closer to 0, the better

  • unique_fps.is_stable (key): whether each fixed point is stable

For an N-bit task, we expect to find 2^N stable fixed points (one per memory state).

The code cell below combines the end of the main function with the last line of the if __name__ == "__main__": block from flipflop_fixed_points.py: it prints the full analysis results and generates the plots.

[8]:
# Print results
print("\n--- Fixed Point Analysis Results ---")
unique_fps.print_summary()

if unique_fps.n > 0:
    print(f"\nDetailed Fixed Point Information (Top 10):")
    print(f"{'#':<4} {'q-value':<12} {'Stability':<12} {'Max |eig|':<12}")
    print("-" * 45)
    for i in range(min(10, unique_fps.n)):
        stability_str = "Stable" if unique_fps.is_stable[i] else "Unstable"
        max_eig = np.abs(unique_fps.eigval_J_xstar[i, 0])
        print(
            f"{i + 1:<4} {unique_fps.qstar[i]:<12.2e} {stability_str:<12} {max_eig:<12.4f}"
        )

    # Visualize fixed points - 2D
    config_2d = PlotConfig(
        title=f"FlipFlop Fixed Points ({config_name} - 2D PCA)",
        xlabel="PC 1", ylabel="PC 2", figsize=(10, 8),
        show=True
    )
    plot_fixed_points_2d(unique_fps, hiddens_np, config=config_2d)

    # Visualize fixed points - 3D
    config_3d = PlotConfig(
        title=f"FlipFlop Fixed Points ({config_name} - 3D PCA)",
        figsize=(12, 10),
        show=True
    )
    plot_fixed_points_3d(
        unique_fps, hiddens_np, config=config_3d,
        plot_batch_idx=list(range(30)), plot_start_time=10
    )

print("\n--- Analysis complete ---")

--- Fixed Point Analysis Results ---

=== Fixed Points Summary ===
Number of fixed points: 25
State dimension: 4
Input dimension: 3

q values: min=0.00e+00, median=6.78e-19, max=9.87e-15
Iterations: min=384, median=384, max=384

Stable fixed points: 8 / 25

Detailed Fixed Point Information (Top 10):
#    q-value      Stability    Max |eig|
---------------------------------------------
1    1.78e-15     Stable       0.2418
2    0.00e+00     Stable       0.2308
3    0.00e+00     Stable       0.2291
4    6.78e-19     Unstable     1.7985
5    1.78e-15     Unstable     2.1942
6    0.00e+00     Stable       0.2417
7    0.00e+00     Stable       0.2435
8    0.00e+00     Stable       0.2291
9    0.00e+00     Stable       0.2436
10   1.78e-15     Stable       0.2307
[Figure: 2D PCA projection of the fixed points and hidden-state trajectories]
  PCA explained variance: [0.46209928 0.27972662 0.25745255]
  Total variance explained: 99.93%
[Figure: 3D PCA projection of the fixed points and hidden-state trajectories]

--- Analysis complete ---
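
A natural follow-up check (a hypothetical sketch; it assumes unique_fps.xstar holds the fixed-point states, which is not confirmed above) is to pass each stable fixed point through the readout layer and verify that the outputs land near the 2³ = 8 corners of {-1, +1}³:

stable_states = np.asarray(unique_fps.xstar)[np.asarray(unique_fps.is_stable)]
readout = stable_states @ np.asarray(rnn.w_out.value) + np.asarray(rnn.b_out.value)
print(np.round(readout))  # expected: the 8 sign patterns of the 3-bit memory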

5. Multi-Configuration Comparison: 2-bit, 3-bit, and 4-bit

Below we run all three configurations to show how the fixed-point analysis results change with task complexity.

Expected results:

  • 2-bit: 4 stable fixed points (2² = 4 memory states)

  • 3-bit: 8 stable fixed points (2³ = 8 memory states)

  • 4-bit: 16 stable fixed points (2⁴ = 16 memory states)

[9]:
import matplotlib.pyplot as plt

def run_flipflop_analysis(config_name, seed=42):
    """Run complete analysis pipeline for a single configuration"""
    config = TASK_CONFIGS[config_name]
    n_bits = config["n_bits"]
    n_hidden = config["n_hidden"]
    n_trials_train = config["n_trials_train"]
    n_inits = config["n_inits"]

    # Set random seeds
    np.random.seed(seed)
    random.seed(seed)

    print(f"\n{'='*60}")
    print(f"Configuration: {config_name} ({n_bits} bits, {n_hidden} hidden units)")
    print(f"{'='*60}")

    # Generate data
    data_gen = FlipFlopData(n_bits=n_bits, n_time=64, p=0.5, random_seed=seed)
    train_data = data_gen.generate_data(n_trials_train)
    valid_data = data_gen.generate_data(128)
    test_data = data_gen.generate_data(128)

    # Create and train model
    rnn = FlipFlopRNN(n_inputs=n_bits, n_hidden=n_hidden,
                      n_outputs=n_bits, rnn_type="tanh", seed=seed)

    checkpoint_path = f"flipflop_rnn_{config_name}_checkpoint.msgpack"
    if not load_checkpoint(rnn, checkpoint_path):
        print(f"Training model...")
        train_flipflop_rnn(rnn, train_data, valid_data,
                          learning_rate=0.08, batch_size=128,
                          max_epochs=500, min_loss=1e-4, print_every=50)
    else:
        print(f"Loaded model from checkpoint: {checkpoint_path}")

    # Get hidden state trajectory
    inputs_jax = jnp.array(test_data["inputs"])
    outputs, hiddens = rnn(inputs_jax)
    hiddens_np = np.array(hiddens)

    # Fixed point analysis
    finder = FixedPointFinder(
        rnn, method="joint", max_iters=5000, lr_init=0.02,
        tol_q=1e-4, final_q_threshold=1e-6, tol_unique=1e-2,
        do_compute_jacobians=True, do_decompose_jacobians=True,
        outlier_distance_scale=10.0, verbose=True, super_verbose=False,
    )

    constant_input = np.zeros((1, n_bits), dtype=np.float32)
    unique_fps, _ = finder.find_fixed_points(
        state_traj=hiddens_np, inputs=constant_input,
        n_inits=n_inits, noise_scale=0.4,
    )

    return unique_fps, hiddens_np, config_name

# Store results for all configurations
all_results = {}
for cfg in ["2_bit", "3_bit", "4_bit"]:
    unique_fps, hiddens_np, name = run_flipflop_analysis(cfg, seed=43)
    all_results[cfg] = {"fps": unique_fps, "hiddens": hiddens_np}

    # Print summary
    n_stable = np.sum(unique_fps.is_stable) if unique_fps.n > 0 else 0
    expected = 2 ** int(cfg[0])
    print(f"\nResults: Found {unique_fps.n} fixed points, {n_stable} stable (expected: {expected})")

============================================================
Configuration: 2_bit (2 bits, 3 hidden units)
============================================================
Training model...

======================================================================
Training FlipFlop RNN (Using brainpy optimizer)
======================================================================

Training parameters:
  Batch size: 128
  Learning rate: 0.080000 (Fixed)
Epoch    0: train_loss = 0.933466, valid_loss = 0.676156
Epoch   50: train_loss = 0.000569, valid_loss = 0.000562
Epoch  100: train_loss = 0.000369, valid_loss = 0.000366
Epoch  150: train_loss = 0.000271, valid_loss = 0.000268
Epoch  200: train_loss = 0.000210, valid_loss = 0.000208
Epoch  250: train_loss = 0.000169, valid_loss = 0.000168
Epoch  300: train_loss = 0.000139, valid_loss = 0.000138
Epoch  350: train_loss = 0.000116, valid_loss = 0.000115

Reached target loss 1.00e-04 at epoch 395

======================================================================
Training Complete!
======================================================================
Final training loss: 0.000100
Final validation loss: 0.000099
Total epochs: 396

Searching for fixed points from 1024 initial states.

        Finding fixed points via joint optimization.
/var/folders/x0/_jqxxbbn0rsdn6b4h6fxbrjr0000gn/T/ipykernel_3897/1594189833.py:52: UserWarning: Joint optimization with n_inits=1024 may be inefficient and use excessive memory. Consider using sequential optimization or reducing n_inits.
  unique_fps, _ = finder.find_fixed_points(
        Optimization complete to desired tolerance.
                170 iters, q = 9.62e-08 +/- 2.84e-06, dq = 1.36e-07 +/- 4.32e-06, lr = 1.63e-02, avg iter time = 7.56e-03 sec
        Identified 9 unique fixed points.
        Computing recurrent Jacobian at 9 unique fixed points.
        Computing input Jacobian at 9 unique fixed points.
Decomposing 9 Jacobians...
Found 4 stable and 5 unstable fixed points.
        Applying final q-value filter (q < 1.0e-06)...
                9 high-quality fixed points remain.
        Fixed point finding complete.


Results: Found 9 fixed points, 4 stable (expected: 4)

============================================================
Configuration: 3_bit (3 bits, 4 hidden units)
============================================================
Training model...

======================================================================
Training FlipFlop RNN (Using brainpy optimizer)
======================================================================

Training parameters:
  Batch size: 128
  Learning rate: 0.080000 (Fixed)
Epoch    0: train_loss = 0.917904, valid_loss = 0.676378
Epoch   50: train_loss = 0.000278, valid_loss = 0.000279
Epoch  100: train_loss = 0.000161, valid_loss = 0.000163
Epoch  150: train_loss = 0.000116, valid_loss = 0.000117

Reached target loss 1.00e-04 at epoch 180

======================================================================
Training Complete!
======================================================================
Final training loss: 0.000100
Final validation loss: 0.000101
Total epochs: 181

Searching for fixed points from 1024 initial states.

        Finding fixed points via joint optimization.
        Optimization complete to desired tolerance.
                162 iters, q = 1.36e-07 +/- 3.03e-06, dq = 1.90e-08 +/- 4.74e-07, lr = 2.00e-02, avg iter time = 4.13e-03 sec
        Identified 27 unique fixed points.
        Computing recurrent Jacobian at 27 unique fixed points.
        Computing input Jacobian at 27 unique fixed points.
Decomposing 27 Jacobians...
Found 9 stable and 18 unstable fixed points.
        Applying final q-value filter (q < 1.0e-06)...
                Excluded 1 low-quality fixed points.
                26 high-quality fixed points remain.
        Fixed point finding complete.


Results: Found 26 fixed points, 8 stable (expected: 8)

============================================================
Configuration: 4_bit (4 bits, 6 hidden units)
============================================================
Training model...

======================================================================
Training FlipFlop RNN (Using brainpy optimizer)
======================================================================

Training parameters:
  Batch size: 128
  Learning rate: 0.080000 (Fixed)
Epoch    0: train_loss = 0.915913, valid_loss = 0.698172
Epoch   50: train_loss = 0.000216, valid_loss = 0.000215
Epoch  100: train_loss = 0.000131, valid_loss = 0.000130

Reached target loss 1.00e-04 at epoch 140

======================================================================
Training Complete!
======================================================================
Final training loss: 0.000100
Final validation loss: 0.000100
Total epochs: 141

Searching for fixed points from 1024 initial states.

        Finding fixed points via joint optimization.
        Optimization complete to desired tolerance.
                333 iters, q = 4.02e-07 +/- 4.73e-06, dq = 1.24e-08 +/- 1.45e-07, lr = 2.00e-02, avg iter time = 5.46e-03 sec
        Identified 74 unique fixed points.
        Computing recurrent Jacobian at 74 unique fixed points.
        Computing input Jacobian at 74 unique fixed points.
Decomposing 74 Jacobians...
Found 22 stable and 52 unstable fixed points.
        Applying final q-value filter (q < 1.0e-06)...
                Excluded 7 low-quality fixed points.
                67 high-quality fixed points remain.
        Fixed point finding complete.


Results: Found 67 fixed points, 16 stable (expected: 16)
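
To gather the comparison in one place, a small hypothetical summary loop over the all_results dict built above:

for cfg, res in all_results.items():
    fps = res["fps"]
    n_stable = int(np.sum(fps.is_stable)) if fps.n > 0 else 0
    print(f"{cfg}: {fps.n} unique fixed points, "
          f"{n_stable} stable (expected {2 ** int(cfg[0])})")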

5.1 2D Visualization Comparison

The 2D PCA projections of the three configurations show directly that the number of fixed points grows with task complexity.

[10]:
# 2D visualization - display each configuration separately
for cfg in ["2_bit", "3_bit", "4_bit"]:
    result = all_results[cfg]
    unique_fps = result["fps"]
    hiddens_np = result["hiddens"]

    n_bits = int(cfg[0])
    n_stable = np.sum(unique_fps.is_stable) if unique_fps.n > 0 else 0

    config_2d = PlotConfig(
        title=f"FlipFlop {cfg}: {n_stable} stable fixed points (2D PCA)",
        xlabel="PC 1", ylabel="PC 2",
        figsize=(8, 6),
        show=True
    )

    plot_fixed_points_2d(unique_fps, hiddens_np, config=config_2d)
[Figures: 2D PCA projections of the fixed points for the 2_bit, 3_bit, and 4_bit configurations]

5.2 3D Visualization Comparison

The 3D PCA projections show the hidden-state trajectories and fixed points in three dimensions, giving a clearer view of the RNN's dynamical structure.

[11]:
# 3D visualization - display each configuration separately
for cfg in ["2_bit", "3_bit", "4_bit"]:
    result = all_results[cfg]
    unique_fps = result["fps"]
    hiddens_np = result["hiddens"]

    n_bits = int(cfg[0])
    n_stable = np.sum(unique_fps.is_stable) if unique_fps.n > 0 else 0

    config_3d = PlotConfig(
        title=f"FlipFlop {cfg}: {n_stable} stable fixed points (3D PCA)",
        figsize=(10, 8),
        show=True
    )

    plot_fixed_points_3d(
        unique_fps, hiddens_np, config=config_3d,
        plot_batch_idx=list(range(20)), plot_start_time=10
    )
  PCA explained variance: [6.6613883e-01 3.3384740e-01 1.3795699e-05]
  Total variance explained: 100.00%
[Figure: 3D PCA projection, 2_bit configuration]
  PCA explained variance: [0.37013304 0.3110425  0.3062212 ]
  Total variance explained: 98.74%
[Figure: 3D PCA projection, 3_bit configuration]
  PCA explained variance: [0.3934802  0.22392422 0.18720938]
  Total variance explained: 80.46%
[Figure: 3D PCA projection, 4_bit configuration]

6. Summary

This tutorial demonstrated how to use FixedPointFinder to analyze the dynamical structure of an RNN:

  1. The FlipFlop task: the RNN must remember the state of several binary channels

  2. Fixed-point analysis: discovers the stable states the RNN uses for "memory" [29, 30]

  3. Visualization: PCA projections show where the fixed points sit in hidden-state space

Key findings:

  • For an N-bit task, the RNN learns to create 2^N stable fixed points

  • These fixed points correspond to the distinct combinations of memory states

  • Fixed-point analysis is a powerful tool for understanding the internal computations of an RNN