import math
import brainstate
import brainunit as u
import jax
from brainunit import Quantity
from matplotlib import pyplot as plt
from ...typing import time_type
from ._base import BasicModel
__all__ = [
# Base Model
"BaseCANN",
"BaseCANN1D",
"BaseCANN2D",
# CANN 1D Models
"CANN1D",
"CANN1D_SFA",
# CANN 2D Models
"CANN2D",
"CANN2D_SFA",
]
[docs]
class BaseCANN(BasicModel):
"""
Base class for Continuous Attractor Neural Network (CANN) models.
This class sets up the fundamental properties of the network, including
neuronal properties, feature space, and the connectivity matrix, which
are shared by different CANN model variations.
"""
def __init__(
self,
shape: int | tuple[int, ...],
**kwargs,
):
"""
Initializes the base CANN model.
Args:
shape (int or tuple): The number of neurons in the network. If an int is provided,
it will be converted to a single-element tuple. If a tuple is provided,
it defines the shape of the network (e.g., (length, length) for 2D).
Internally, shape is always stored as a tuple.
**kwargs: Additional keyword arguments passed to the parent BasicModel.
"""
if isinstance(shape, int):
self.shape = (shape,)
elif isinstance(shape, tuple):
self.shape = shape
else:
raise TypeError("shape must be an int or a tuple of ints")
super().__init__(math.prod(self.shape), **kwargs)
[docs]
def make_conn(self):
"""
Constructs the connectivity matrix for the CANN model.
This method should be implemented in subclasses to define how neurons
are connected based on their feature preferences.
Returns:
Array: A connectivity matrix defining the synaptic strengths between neurons.
"""
raise NotImplementedError("This method should be implemented in subclasses.")
[docs]
def get_stimulus_by_pos(self, pos):
"""
Generates an external stimulus based on a given position in the feature space.
This method should be implemented in subclasses to define how the stimulus is shaped.
Args:
pos (float or Array): The position in the feature space where the stimulus is centered.
Returns:
Array: An array of stimulus values for each neuron.
"""
raise NotImplementedError("This method should be implemented in subclasses.")
[docs]
class BaseCANN1D(BaseCANN):
"""
Base class for 1D Continuous Attractor Neural Network (CANN) models.
This class sets up the fundamental properties of the network, including
neuronal properties, feature space, and the connectivity matrix, which
are shared by different CANN model variations.
"""
def __init__(
self,
num: int,
tau: time_type = 1.0,
k: float = 8.1,
a: float = 0.5,
A: float = 10,
J0: float = 4.0,
z_min: float = -u.math.pi,
z_max: float = u.math.pi,
**kwargs,
):
"""
Initializes the base 1D CANN model.
Args:
num (int): The number of neurons in the network.
tau (float): The synaptic time constant, controlling how quickly the membrane potential changes.
k (float): A parameter controlling the strength of the global inhibition.
a (float): The half-width of the excitatory connection range. It defines the "spread" of local connections.
A (float): The magnitude (amplitude) of the external stimulus.
J0 (float): The maximum connection strength between neurons.
z_min (float): The minimum value of the feature space (e.g., -pi for an angle).
z_max (float): The maximum value of the feature space (e.g., +pi for an angle).
**kwargs: Additional keyword arguments passed to the parent BasicModel.
"""
super().__init__(num, **kwargs)
# --- Model Parameters ---
[docs]
self.tau = tau # Synaptic time constant.
[docs]
self.k = k # Degree of the rescaled inhibition.
[docs]
self.a = a # Half-width of the range of excitatory connections.
[docs]
self.A = A # Magnitude of the external input.
[docs]
self.J0 = J0 # Maximum connection value (amplitude of the connectivity kernel).
# --- Feature Space Properties ---
[docs]
self.z_min = z_min # Minimum of the feature space.
[docs]
self.z_max = z_max # Maximum of the feature space.
[docs]
self.z_range = z_max - z_min # The total range of the feature space.
# An array representing the preferred feature value for each neuron.
[docs]
self.x = u.math.linspace(z_min, z_max, num)
[docs]
self.rho = num / self.z_range # The neural density.
[docs]
self.dx = self.z_range / num # The stimulus density
# --- Connectivity Matrix ---
# The connection matrix, defining the strength of synapses between all pairs of neurons.
[docs]
self.conn_mat = self.make_conn()
[docs]
def dist(self, d):
"""
Calculates the shortest distance between two points in a circular feature space
with periodic boundary conditions.
Args:
d (Array): The difference between two positions.
Returns:
Array: The shortest distance, wrapped around the periodic boundary.
"""
# Apply periodic boundary condition using the remainder.
d = u.math.remainder(d, self.z_range)
# Ensure the distance is the shortest path (e.g., the distance between 350 and 10 degrees is 20, not 340).
d = u.math.where(d > self.z_range / 2, d - self.z_range, d)
return d
[docs]
def make_conn(self):
"""
Constructs the connectivity matrix based on a Gaussian-like profile.
The connection strength between two neurons depends on the distance
between their preferred feature values in the circular space.
Returns:
Array: A (num x num) connectivity matrix.
"""
# Prepare coordinate arrays to compute pairwise distances.
x_left = u.math.reshape(self.x, (-1, 1))
x_right = u.math.repeat(self.x.reshape((1, -1)), len(self.x), axis=0)
# Calculate the pairwise distance matrix with periodic boundaries.
d = self.dist(x_left - x_right)
# Compute the connection strengths using a Gaussian (normal distribution) function.
# Neurons with similar feature preferences will have stronger excitatory connections.
return (
self.J0
* u.math.exp(-0.5 * u.math.square(d / self.a))
/ (u.math.sqrt(2 * u.math.pi) * self.a)
)
[docs]
def get_stimulus_by_pos(self, pos):
"""
Generates a Gaussian-shaped external stimulus centered at a given position.
Args:
pos (float): The center position of the stimulus in the feature space.
Returns:
Array: An array of stimulus values for each neuron.
"""
# The stimulus is a "bump" of activity, modeled by a Gaussian function.
return self.A * u.math.exp(-0.25 * u.math.square(self.dist(self.x - pos) / self.a))
[docs]
class CANN1D(BaseCANN1D):
"""
A standard 1D Continuous Attractor Neural Network (CANN) model.
This model implements the core dynamics where a localized "bump" of activity
can be sustained and moved by external inputs.
Reference:
Wu, S., Hamaguchi, K., & Amari, S. I. (2008). Dynamics and computation of continuous attractors.
Neural computation, 20(4), 994-1025.
"""
[docs]
def init_state(self, *args, **kwargs):
"""Initializes the state variables of the model."""
# --- State Variables ---
# Firing rate of the neurons.
self.r = brainstate.HiddenState(u.math.zeros(self.varshape))
# Synaptic input to the neurons.
self.u = brainstate.HiddenState(u.math.zeros(self.varshape))
# --- Inputs ---
# External input to the network.
self.inp = brainstate.State(u.math.zeros(self.varshape))
[docs]
def update(self, inp):
"""
The main update function, defining the dynamics of the network for one time step.
Args:
inp (Array): The external input for the current time step.
"""
self.inp.value = inp
# The numerator for the firing rate calculation (a non-linear activation function).
r1 = u.math.square(self.u.value)
# The denominator, which implements global divisive inhibition.
r2 = 1.0 + self.k * u.math.sum(r1)
# Calculate the firing rate of each neuron using divisive normalization.
self.r.value = r1 / r2
# Calculate the recurrent input from other neurons in the network.
Irec = u.math.dot(self.conn_mat, self.r.value)
# Update the synaptic inputs using Euler's method. The change depends on a leak
# current (-u), recurrent input (Irec), and external input (inp).
self.u.value += (
(-self.u.value + Irec + self.inp.value) / self.tau * brainstate.environ.get_dt()
)
[docs]
class CANN1D_SFA(BaseCANN1D):
"""
A 1D CANN model that incorporates Spike-Frequency Adaptation (SFA).
SFA is a slow negative feedback mechanism that causes neurons to fire less
over time for a sustained input, which can induce anticipative tracking behavior.
Reference:
Mi, Y., Fung, C. C., Wong, K. Y., & Wu, S. (2014). Spike frequency adaptation
implements anticipative tracking in continuous attractor neural networks.
Advances in neural information processing systems, 27.
"""
def __init__(
self,
num: int,
tau: time_type = 1.0,
tau_v: time_type = 50.0,
k: float = 8.1,
a: float = 0.3,
A: float = 0.2,
J0: float = 1.0,
z_min: float = -u.math.pi,
z_max: float = u.math.pi,
m: float = 0.3,
**kwargs,
):
"""
Initializes the 1D CANN model with SFA.
Args:
tau_v (float): The time constant for the adaptation variable 'v'. A larger value means slower adaptation.
m (float): The strength of the adaptation, coupling the membrane potential 'u' to the adaptation variable 'v'.
(Other parameters are inherited from BaseCANN1D)
"""
super().__init__(num, tau, k, a, A, J0, z_min, z_max, **kwargs)
# --- SFA-specific Parameters ---
[docs]
self.tau_v = tau_v # Time constant of the adaptation variable.
[docs]
self.m = m # Strength of the adaptation.
[docs]
def init_state(self, *args, **kwargs):
"""Initializes the state variables of the model, including the adaptation variable."""
# --- State Variables ---
self.r = brainstate.HiddenState(u.math.zeros(self.varshape)) # Firing rate.
self.u = brainstate.HiddenState(u.math.zeros(self.varshape)) # Synaptic inputs.
# self.v: The adaptation variable, which tracks the synaptic inputs 'u' and provides negative feedback.
self.v = brainstate.HiddenState(u.math.zeros(self.varshape))
# --- Inputs ---
self.inp = brainstate.State(u.math.zeros(self.varshape)) # External input.
[docs]
def update(self, inp):
"""
The main update function for the SFA model. It includes dynamics for both
the membrane potential and the adaptation variable.
Args:
inp (Array): The external input for the current time step.
"""
self.inp.value = inp
# Firing rate calculation is the same as the standard CANN model.
r1 = u.math.square(self.u.value)
r2 = 1.0 + self.k * u.math.sum(r1)
self.r.value = r1 / r2
# Calculate recurrent input.
Irec = u.math.dot(self.conn_mat, self.r.value)
# Update the synaptic input. Note the additional '- self.v.value' term,
self.u.value += (
(-self.u.value + Irec + self.inp.value - self.v.value)
/ self.tau
* brainstate.environ.get_dt()
)
# Update the adaptation variable 'v'. It slowly tracks the membrane potential 'u'
# and has its own decay, creating a slow negative feedback loop.
self.v.value += (
(-self.v.value + self.m * self.u.value) / self.tau_v * brainstate.environ.get_dt()
)
[docs]
class BaseCANN2D(BaseCANN):
"""
Base class for 2D Continuous Attractor Neural Network (CANN) models.
This class sets up the fundamental properties of the network, including
neuronal properties, feature space, and the connectivity matrix, which
are shared by different CANN model variations.
"""
def __init__(
self,
length: int,
tau: Quantity | float = 1.0,
k: float = 8.1,
a: float = 0.5,
A: float = 10,
J0: float = 4.0,
z_min: float = -u.math.pi,
z_max: float = u.math.pi,
**kwargs,
):
"""
Initializes the base 2D CANN model.
Args:
length (int): The number of neurons in one dimension of the network (the network is square).
tau (float): The synaptic time constant, controlling how quickly the membrane potential changes.
k (float): A parameter controlling the strength of the global inhibition.
a (float): The half-width of the excitatory connection range. It defines the "spread" of local connections.
A (float): The magnitude (amplitude) of the external stimulus.
J0 (float): The maximum connection strength between neurons.
z_min (float): The minimum value of the feature space (e.g., -pi for an angle).
z_max (float): The maximum value of the feature space (e.g., +pi for an angle).
**kwargs: Additional keyword arguments passed to the parent BasicModel.
"""
super().__init__((self.length,) * 2, **kwargs) # square network of neurons
# --- Model Parameters ---
[docs]
self.tau = tau # Synaptic time constant.
[docs]
self.k = k # Degree of the rescaled inhibition.
[docs]
self.a = a # Half-width of the range of excitatory connections.
[docs]
self.A = A # Magnitude of the external input.
[docs]
self.J0 = J0 # Maximum connection value (amplitude of the connectivity kernel).
# --- Feature Space Properties ---
[docs]
self.z_min = z_min # Minimum of the feature space.
[docs]
self.z_max = z_max # Maximum of the feature space.
[docs]
self.z_range = z_max - z_min # The total range of the feature space.
# An array representing the preferred feature value for each neuron.
[docs]
self.x = u.math.linspace(z_min, z_max, length)
[docs]
self.rho = length / self.z_range # The neural density.
[docs]
self.dx = self.z_range / length # The stimulus density
# --- Connectivity Matrix ---
# The connection matrix, defining the strength of synapses between all pairs of neurons.
[docs]
self.conn_mat = self.make_conn()
[docs]
def show_conn(self):
"""
Displays the connectivity matrix as an image.
This method visualizes the connection strengths between neurons in the 2D feature space.
"""
plt.imshow(u.math.as_numpy(self.conn_mat))
plt.colorbar()
plt.show()
[docs]
def dist(self, d):
"""
Calculates the shortest distance vector between two points in a 2D feature space
with periodic boundary conditions (a torus).
Args:
d (Array): The difference vector between two positions, e.g., [dx, dy].
Returns:
Array: The shortest distance vector, with each component wrapped around
the periodic boundary.
"""
# Define the size of the periodic box for each dimension.
box_size = u.math.asarray([self.z_range, self.z_range])
# Apply the periodic boundary condition to each component of the vector
# using the remainder. This wraps the differences into the [0, box_size) interval.
d = u.math.remainder(d, box_size)
# Ensure each component of the distance vector is the shortest path.
# For example, in a dimension of size 360, the distance between 350 and 10
# should be -20 (magnitude 20), not 340.
d = u.math.where(d > box_size / 2, d - box_size, d)
return d
[docs]
def make_conn(self):
"""
Constructs the connectivity matrix for a 2D grid of neurons based on a
Gaussian profile. The connection strength between two neurons depends on the
Euclidean distance between their preferred feature coordinates in a 2D
toroidal space (space with periodic boundaries in both dimensions).
Returns:
Array: A ((num*num) x (num*num)) connectivity matrix.
"""
# Create a 2D grid of coordinates for all neurons.
x1, x2 = u.math.meshgrid(self.x, self.x)
# Reshape the grid into a list of coordinate pairs.
all_coords = u.math.stack([x1.flatten(), x2.flatten()]).T
# Define a function to compute connectivity from one neuron to all others.
@jax.vmap
def get_conn_for_one_neuron(source_coord):
# Calculate the difference vectors from this source neuron to all other neurons.
# The self.dist function handles the periodic boundary conditions for each vector component.
diff_vectors = self.dist(source_coord - all_coords)
# Calculate the scalar Euclidean distance (L2 norm) for each difference vector.
# This gives the true shortest distance in the 2D toroidal space.
scalar_distances = u.math.linalg.norm(diff_vectors, axis=1)
# Compute connection strengths using the same Gaussian (normal distribution) function.
# Neurons with closer coordinates will have stronger excitatory connections.
conn_strengths = (
self.J0
* u.math.exp(-0.5 * u.math.square(scalar_distances / self.a))
/ (u.math.sqrt(2 * u.math.pi) * self.a)
)
return conn_strengths
return get_conn_for_one_neuron(all_coords)
[docs]
def get_stimulus_by_pos(self, pos):
"""
Generates a Gaussian-shaped external stimulus centered at a given
coordinate on the 2D neural grid.
Args:
pos (Array, tuple): The center coordinate [x, y] of the stimulus
in the feature space.
Returns:
Array: A 2D array (grid) of stimulus values for each neuron.
"""
# Validate that the input position is two-dimensional.
pos = u.math.asarray(pos)
assert pos.shape == (2,), "Input position must be a 2D coordinate, e.g., [x, y]."
# Create a 2D grid of coordinates for all neurons.
x1, x2 = u.math.meshgrid(self.x, self.x)
all_coords = u.math.stack([x1.flatten(), x2.flatten()]).T
# Calculate the distance from the stimulus center to every neuron.
diff_vectors = self.dist(all_coords - pos)
scalar_distances = u.math.linalg.norm(diff_vectors, axis=1)
# Calculate the stimulus intensity using a Gaussian function.
stimulus_flat = self.A * u.math.exp(-0.25 * u.math.square(scalar_distances / self.a))
# Reshape the flat stimulus array back into a 2D grid.
num_neurons_per_dim = self.x.shape[0]
return stimulus_flat.reshape((num_neurons_per_dim, num_neurons_per_dim))
[docs]
class CANN2D(BaseCANN2D):
"""
A 2D Continuous Attractor Neural Network (CANN) model.
This model extends the base CANN2D class to include specific dynamics
and properties for a 2D neural network.
Reference:
Wu, S., Hamaguchi, K., & Amari, S. I. (2008). Dynamics and computation of continuous attractors.
Neural computation, 20(4), 994-1025.
"""
[docs]
def init_state(self, *args, **kwargs):
"""
Initializes the state variables of the model.
"""
# --- State Variables ---
# Firing rate of the neurons.
self.r = brainstate.HiddenState(u.math.zeros((self.length, self.length)))
# Synaptic input to the neurons.
self.u = brainstate.HiddenState(u.math.zeros((self.length, self.length)))
# --- Inputs ---
# External input to the neurons
self.inp = brainstate.State(u.math.zeros((self.length, self.length)))
[docs]
def update(self, inp):
"""
The main update function, defining the dynamics of the network for one time step.
Args:
inp (Array): The external input to the network, which can be a stimulus or other driving force.
"""
self.inp.value = inp
# The numerator for the firing rate calculation (a non-linear activation function).
r1 = u.math.square(self.u.value)
# The denominator, which implements global divisive inhibition.
r2 = 1.0 + self.k * u.math.sum(r1)
# Calculate the firing rate of each neuron using divisive normalization.
self.r.value = r1 / r2
# Calculate the recurrent input from other neurons in the network.
Irec = (self.r.value.flatten() @ self.conn_mat).reshape((self.length, self.length))
# Update the synaptic input based on the recurrent input and external input.
self.u.value += (
(-self.u.value + Irec + self.inp.value) / self.tau * brainstate.environ.get_dt()
)
[docs]
class CANN2D_SFA(BaseCANN2D):
"""
A 2D Continuous Attractor Neural Network (CANN) model with a specific
implementation of the Synaptic Firing Activity (SFA) dynamics.
This model extends the base CANN2D class to include SFA-specific dynamics.
"""
def __init__(
self,
length: int,
tau: Quantity | float = 1.0,
tau_v: Quantity | float = 50.0,
k: float = 8.1,
a: float = 0.3,
A: float = 0.2,
J0: float = 1.0,
z_min: float = -u.math.pi,
z_max: float = u.math.pi,
m: float = 0.3,
**kwargs,
):
"""
Initializes the 2D CANN model with SFA dynamics.
"""
super().__init__(
length=length, tau=tau, k=k, a=a, A=A, J0=J0, z_min=z_min, z_max=z_max, **kwargs
)
# --- SFA-specific Parameters ---
[docs]
self.tau_v = tau_v # Time Constant of the adaptation variable.
[docs]
self.m = m # Strength of the adaptation.
[docs]
def init_state(self, *args, **kwargs):
"""Initializes the state variables of the model, including the adaptation variable."""
# --- State Variables ---
self.r = brainstate.HiddenState(u.math.zeros((self.length, self.length))) # Firing rate.
self.u = brainstate.HiddenState(u.math.zeros((self.length, self.length))) # Synaptic input.
# self.v: The adaptation variable, which tracks the synaptic inputs 'u' and provides negative feedback.
self.v = brainstate.HiddenState(u.math.zeros((self.length, self.length)))
# --- Inputs ---
self.inp = brainstate.State(u.math.zeros((self.length, self.length))) # External input.
[docs]
def update(self, inp):
"""
The main update function for the SFA model. It includes dynamics for both
the membrane potential and the adaptation variable.
Args:
inp (Array): The external input for the current time step.
"""
self.inp.value = inp
# Firing rate calculation is the same as the standard CANN model.
r1 = u.math.square(self.u.value)
r2 = 1.0 + self.k * u.math.sum(r1)
self.r.value = r1 / r2
# Calculate recurrent input.
Irec = (self.r.value.flatten() @ self.conn_mat).reshape((self.length, self.length))
# Update the synaptic input. Note the additional '- self.v.value' term,
self.u.value += (
(-self.u.value + Irec + self.inp.value - self.v.value)
/ self.tau
* brainstate.environ.get_dt()
)
# Update the adaptation variable 'v'. It slowly tracks the membrane potential 'u'
# and has its own decay, creating a slow negative feedback loop.
self.v.value += (
(-self.v.value + self.m * self.u.value) / self.tau_v * brainstate.environ.get_dt()
)