from collections.abc import Sequence
import brainunit as u
import numpy as np
from tqdm import tqdm
from ..models.basic.cann import BaseCANN, BaseCANN1D, BaseCANN2D
from ..typing import Iext_type, time_type
from ._base import Task
# Public API of this module: only the concrete 1D/2D task classes are exported.
# The n-D base classes (TrackingTask, PopulationCoding, TemplateMatching,
# SmoothTracking) and the CustomTracking1D/2D templates are not listed here.
__all__ = [
"PopulationCoding1D",
"TemplateMatching1D",
"SmoothTracking1D",
"PopulationCoding2D",
"TemplateMatching2D",
"SmoothTracking2D",
]
class TrackingTask(Task):
    """
    A task for simulating the tracking of external inputs in an n-D CANN.

    The task turns a sequence of input positions and their corresponding
    durations into a per-time-step stimulus protocol. The constructor only
    validates the configuration and derives the timing attributes; the actual
    arrays are produced by `get_data`, which stores:

    - `Iext_sequence`: array of shape (total_steps, ndim) holding the input
      position at every time step.
    - `data`: array of shape (total_steps, *shape) holding the stimulus returned
      by the CANN instance's `get_stimulus_by_pos` for every time step.
    """

    def __init__(
        self,
        ndim: int,
        config: dict | None = None,
        **kwargs,
    ):
        """Validate the configuration and derive the timing attributes.

        Args:
            ndim (int): The dimensionality of the continuous attractor network space.
            config (dict): A dictionary containing the parameters for the
                tracking simulation. Expected keys are:
                - Iext (Sequence[float]): A sequence of positions for the
                  external input stimulus. Defaults to an empty list.
                - duration (Sequence[float]): A sequence of durations, where each
                  duration corresponds to an input position in `Iext`.
                  Defaults to an empty list.
                - time_step (float, optional): The simulation time step.
                  Defaults to 0.1.
                - cann_instance (BaseCANN): An instance of the CANN model
                  whose `shape` and `get_stimulus_by_pos` are used to build
                  the stimulus patterns.
            **kwargs: Accepted for signature compatibility; not used by this
                constructor.

        Raises:
            ValueError: If `config` is missing, if `Iext` or `duration` is not a
                sequence, or if `cann_instance` is not a `BaseCANN`.
        """
        super().__init__()

        # Validate everything up front so that a bad configuration is reported
        # with a clear message instead of failing inside the arithmetic below.
        if config is None:
            raise ValueError("A 'config' dictionary is required to build a tracking task.")

        duration = config.get("duration", [])
        Iext = config.get("Iext", [])
        if Iext is None or not isinstance(Iext, Sequence):
            raise ValueError("Configuration must include 'Iext' as a sequence of input positions.")
        if duration is None or not isinstance(duration, Sequence):
            raise ValueError("Configuration must include 'duration' as a sequence of time values.")

        cann_instance = config.get("cann_instance", None)
        if cann_instance is None or not isinstance(cann_instance, BaseCANN):
            raise ValueError(
                "Configuration must include 'cann_instance' as an instance of BaseCANN."
            )

        self.duration = duration
        self.Iext = Iext
        self.ndim = ndim

        # Simulation time control
        self.current_step = 0
        self.time_step = config.get("time_step", 0.1)
        self.total_duration = np.sum(self.duration)
        # Round up so that a trailing partial step is still simulated.
        self.total_steps = np.ceil(self.total_duration / self.time_step).astype(dtype=int)
        self.run_steps = u.math.arange(0, self.total_duration, self.time_step)

        # Only the network shape and its stimulus generator are kept from the model.
        self.shape = cann_instance.shape
        self.get_stimulus_by_pos = cann_instance.get_stimulus_by_pos

    def _make_Iext_sequence(self):
        """
        Creates a time-series array of external input positions.

        This method generates a step-function sequence where each input `Iext[i]`
        is held constant for the corresponding `duration[i]`.

        Returns:
            np.ndarray: Array of shape (total_steps, ndim) with the external input
            position at each time step.
        """
        Iext_sequence = np.zeros((self.total_steps, self.ndim), dtype=float)
        start_step = 0
        # NOTE(review): int() truncates, so a duration that is not an exact multiple
        # of time_step loses its fractional step here; the shortfall is absorbed by
        # the trailing fill below.
        dur_steps = [int(dur / self.time_step) for dur in self.duration]
        # strict=False: if Iext and duration differ in length the shorter one wins.
        for num_steps, iext_val in zip(dur_steps, self.Iext, strict=False):
            end_step = start_step + num_steps
            Iext_sequence[start_step:end_step, :] = iext_val
            start_step = end_step
        # If total duration is not perfectly divisible, fill the remainder with the last value.
        if start_step < self.total_steps:
            Iext_sequence[start_step:] = self.Iext[-1]
        return Iext_sequence

    def get_data(self, progress_bar: bool = True):
        """
        Generates the task data from the `Iext` and `duration` parameters.

        Builds `self.Iext_sequence` and then evaluates `get_stimulus_by_pos` once
        per time step, storing the result in `self.data`. Nothing is returned.

        Args:
            progress_bar (bool): Whether to display a tqdm progress bar.
        """
        self.Iext_sequence = self._make_Iext_sequence()
        n_steps = len(self.Iext_sequence)
        data = np.zeros((n_steps, *self.shape), dtype=float)
        for i, pos in tqdm(
            enumerate(self.Iext_sequence),
            total=n_steps,  # enumerate() has no length, so tell tqdm explicitly
            desc=f"<{type(self).__name__}> Generating Task data",
            disable=not progress_bar,
        ):
            data[i] = self.get_stimulus_by_pos(pos)
        self.data = data

    def show_data(
        self,
        show=True,
        save_path=None,
    ):
        """Visualization hook; always raises because this class has no plot.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "The show_data method is not implemented for TrackingTask. "
            "Please implement this method in subclasses to visualize the task data."
        )
class PopulationCoding(TrackingTask):
    """
    Population coding task for n-D continuous attractor networks.

    In this task, a stimulus is presented for a specific duration, preceded and followed by
    periods of no stimulation, to test the network's ability to form and maintain a memory bump.
    """

    def __init__(
        self,
        cann_instance: BaseCANN,
        ndim: int,
        before_duration: time_type,
        after_duration: time_type,
        Iext: Iext_type,
        duration: time_type,
        time_step: time_type = 0.1,
    ):
        """
        Initializes the Population Coding task.

        Args:
            cann_instance (BaseCANN): An instance of the CANN model.
            ndim (int): The dimensionality of the continuous attractor network.
            before_duration (float | Quantity): Duration of the initial period with no stimulus.
            after_duration (float | Quantity): Duration of the final period with no stimulus.
            Iext (float | Quantity): The position of the external input during the stimulation period.
            duration (float | Quantity): The duration of the stimulation period.
            time_step (float | Quantity, optional): The simulation time step. Defaults to 0.1.
        """
        # The task is structured as: no input -> input -> no input.
        # The base class works on sequences, so the same position is repeated for all
        # three phases; `get_data` leaves the before/after phases at zero.
        super().__init__(
            ndim=ndim,
            config={
                "cann_instance": cann_instance,
                "Iext": (Iext, Iext, Iext),  # Repeated for before, during, and after phases.
                "duration": (before_duration, duration, after_duration),  # Duration for each phase.
                "time_step": time_step,  # Time step for the simulation.
            },
        )
        self.before_duration = before_duration
        self.after_duration = after_duration

    def get_data(self, progress_bar: bool = True):
        """
        Builds `self.data`: zeros everywhere except the stimulation window.

        The stimulus pattern is computed once and copied into every time step
        between the end of the `before` phase and the start of the `after` phase.

        Args:
            progress_bar (bool): When true, a one-line status message is printed
                (no loop is run, so there is no real progress bar).
        """
        self.Iext_sequence = self._make_Iext_sequence()
        shape = (self.total_steps,) + self.shape
        data = np.zeros(shape, dtype=float)

        # Determine the time boundaries for applying the stimulus.
        start_time_step = int(self.before_duration / self.time_step)
        end_time_step = int((self.total_duration - self.after_duration) / self.time_step)

        if progress_bar:
            print(f"<{type(self).__name__}>Generating Task data(No For Loop)")

        # An empty window (zero stimulation time) must not index into the sequence:
        # with no trailing phase, start_time_step can equal total_steps.
        if start_time_step < end_time_step:
            stimulus = self.get_stimulus_by_pos(self.Iext_sequence[start_time_step])
            data[start_time_step:end_time_step] = stimulus

        self.data = data
class TemplateMatching(TrackingTask):
    """
    Template matching task for n-D continuous attractor networks.

    A single stimulus pattern is corrupted with fresh Gaussian noise at every
    time step, so the network can be tested on recovering the underlying
    pattern (template) from a noisy input.
    """

    def __init__(
        self,
        cann_instance: BaseCANN,
        ndim: int,
        Iext: Iext_type,
        duration: time_type,
        time_step: time_type = 0.1,
    ):
        """
        Set up a single-phase task at one input position.

        Args:
            cann_instance (BaseCANN): The CANN model; its `A` attribute is read
                and used to scale the noise.
            ndim (int): The dimensionality of the continuous attractor network.
            Iext (float | Quantity): The position of the external input.
            duration (float | Quantity): How long the noisy stimulus is presented.
            time_step (float | Quantity, optional): The simulation time step. Defaults to 0.1.
        """
        # One position held for one duration: wrap both in 1-tuples for the base class.
        task_config = {
            "cann_instance": cann_instance,
            "Iext": (Iext,),
            "duration": (duration,),
            "time_step": time_step,
        }
        super().__init__(ndim=ndim, config=task_config)
        # Reference amplitude; the noise added in get_data is 0.1 * A * N(0, 1).
        self.A = cann_instance.A

    def get_data(self, progress_bar: bool = True):
        """
        Fill `self.data` with the template plus independent noise per time step.

        Args:
            progress_bar (bool): Whether to display a tqdm progress bar.
        """
        self.Iext_sequence = self._make_Iext_sequence()
        noisy_frames = np.zeros((self.total_steps,) + self.shape, dtype=float)

        # The position never changes, so the clean pattern is computed only once.
        template = self.get_stimulus_by_pos(self.Iext_sequence[0])

        step_iter = tqdm(
            range(self.total_steps),
            desc=f"<{type(self).__name__}>Generating Task data",
            disable=not progress_bar,
        )
        for step in step_iter:
            noisy_frames[step] = template + 0.1 * self.A * np.random.randn(*self.shape)

        self.data = noisy_frames
class SmoothTracking(TrackingTask):
    """
    Smooth tracking task for n-D continuous attractor networks.

    The external input moves linearly from one keypoint to the next, so the
    network can be tested on following a continuously changing stimulus.
    """

    def __init__(
        self,
        cann_instance: BaseCANN,
        ndim: int,
        Iext: Sequence[Iext_type],
        duration: Sequence[time_type],
        time_step: time_type = 0.1,
    ):
        """
        Set up a piecewise-linear trajectory through the given keypoints.

        Args:
            cann_instance (BaseCANN): An instance of the CANN model.
            ndim (int): The dimensionality of the continuous attractor network.
            Iext (Sequence[float | Quantity]): Keypoint positions; must contain
                exactly one more entry than `duration`.
            duration (Sequence[float | Quantity]): Travel time of each segment
                between consecutive keypoints.
            time_step (float | Quantity, optional): The simulation time step. Defaults to 0.1.
        """
        # N segments need N + 1 keypoints (a start and an end for each segment).
        assert len(tuple(Iext)) - 1 == len(tuple(duration)), (
            "Iext must have one more element than duration to define start and end points for each segment."
        )
        task_config = {
            "cann_instance": cann_instance,
            "Iext": Iext,
            "duration": duration,
            "time_step": time_step,
        }
        super().__init__(ndim=ndim, config=task_config)

    def _make_Iext_sequence(self):
        """
        Build the per-step input positions by linear interpolation between keypoints.

        Returns:
            np.ndarray: Array of shape (total_steps, ndim). Each segment is sampled
            with `np.linspace` from its start keypoint to its end keypoint; any
            steps left over at the end hold the final keypoint.
        """
        trajectory = np.zeros((self.total_steps, self.ndim), dtype=float)
        cursor = 0

        for seg_idx, seg_duration in enumerate(self.duration):
            seg_len = int(seg_duration / self.time_step)
            if seg_len == 0:
                # Segment shorter than one time step: contributes no samples.
                continue
            seg_stop = cursor + seg_len

            if self.ndim == 1:
                # Keypoints are scalars; turn the samples into a single column.
                samples = np.linspace(self.Iext[seg_idx], self.Iext[seg_idx + 1], seg_len)
                trajectory[cursor:seg_stop] = samples.reshape(-1, 1)
            else:
                # Keypoints are indexable per dimension; interpolate each axis on its own.
                origin = self.Iext[seg_idx]
                target = self.Iext[seg_idx + 1]
                for axis in range(self.ndim):
                    trajectory[cursor:seg_stop, axis] = np.linspace(
                        origin[axis], target[axis], seg_len
                    )

            cursor = seg_stop

        # Hold the last keypoint for whatever steps remain (broadcast over ndim).
        if cursor < self.total_steps:
            trajectory[cursor:] = self.Iext[-1]

        return trajectory
class PopulationCoding1D(PopulationCoding):
    """
    Population coding task for 1D continuous attractor networks.

    In this task, a stimulus is presented for a specific duration, preceded and followed by
    periods of no stimulation, to test the network's ability to form and maintain a memory bump.
    """

    def __init__(
        self,
        cann_instance: BaseCANN1D,
        before_duration: time_type,
        after_duration: time_type,
        Iext: Iext_type,
        duration: time_type,
        time_step: time_type = 0.1,
    ):
        """
        Initializes the Population Coding task with `ndim` fixed to 1.

        Args:
            cann_instance (BaseCANN1D): An instance of the 1D CANN model.
            before_duration (float | Quantity): Duration of the initial period with no stimulus.
            after_duration (float | Quantity): Duration of the final period with no stimulus.
            Iext (float | Quantity): The position of the external input during the stimulation period.
            duration (float | Quantity): The duration of the stimulation period.
            time_step (float | Quantity, optional): The simulation time step. Defaults to 0.1.
        """
        # The parent constructor stores before_duration / after_duration itself,
        # so nothing else needs to be set here.
        super().__init__(
            cann_instance=cann_instance,
            ndim=1,
            before_duration=before_duration,
            after_duration=after_duration,
            Iext=Iext,
            duration=duration,
            time_step=time_step,
        )
class TemplateMatching1D(TemplateMatching):
    """
    Template matching task for 1D continuous attractor networks.

    This task presents a stimulus with added noise to test the network's ability to
    denoise the input and settle on the correct underlying pattern (template).
    """

    def __init__(
        self,
        cann_instance: BaseCANN1D,
        Iext: Iext_type,
        duration: time_type,
        time_step: time_type = 0.1,
    ):
        """
        Initializes the Template Matching task with `ndim` fixed to 1.

        Args:
            cann_instance (BaseCANN1D): An instance of the 1D CANN model.
            Iext (float | Quantity): The position of the external input.
            duration (float | Quantity): The duration for which the noisy stimulus is presented.
            time_step (float | Quantity, optional): The simulation time step. Defaults to 0.1.
        """
        super().__init__(
            cann_instance=cann_instance,
            ndim=1,
            Iext=Iext,
            duration=duration,
            time_step=time_step,
        )
class SmoothTracking1D(SmoothTracking):
    """
    Smooth tracking task for 1D continuous attractor networks.

    This task provides an external input that moves smoothly over time, testing the network's
    ability to track a continuously changing stimulus.
    """

    def __init__(
        self,
        cann_instance: BaseCANN1D,
        Iext: Sequence[Iext_type],
        duration: Sequence[time_type],
        time_step: time_type = 0.1,
    ):
        """
        Initializes the Smooth Tracking task with `ndim` fixed to 1.

        Args:
            cann_instance (BaseCANN1D): An instance of the 1D CANN model.
            Iext (Sequence[float | Quantity]): A sequence of keypoint positions for the input;
                the parent class requires one more keypoint than there are durations.
            duration (Sequence[float | Quantity]): The duration of each segment of smooth movement.
            time_step (float | Quantity, optional): The simulation time step. Defaults to 0.1.
        """
        super().__init__(
            cann_instance=cann_instance,
            ndim=1,
            Iext=Iext,
            duration=duration,
            time_step=time_step,
        )
class CustomTracking1D(TrackingTask):
    """
    A template class for creating custom 1D tracking tasks.

    Users should inherit from this class and implement their own logic for
    `_make_Iext_sequence` and/or `update` to define a new task. Both methods
    raise `NotImplementedError` here.
    """

    def __init__(self, *args, **kwargs):
        """Initializes the custom task using the base class constructor with ndim=1.

        Args:
            *args: Forwarded to `TrackingTask.__init__` after `ndim` (i.e. the
                first positional argument is `config`).
            **kwargs: Forwarded to `TrackingTask.__init__`.
        """
        # `ndim` is the first positional parameter of TrackingTask.__init__. Passing
        # it as a keyword after *args made any positional argument collide with it
        # ("multiple values for argument 'ndim'"), so it is supplied positionally.
        super().__init__(1, *args, **kwargs)

    def _make_Iext_sequence(self):
        """
        Placeholder for custom input sequence generation.

        This method should be overridden to create a specific time-series of inputs.

        Raises:
            NotImplementedError: Always, until a subclass overrides it.
        """
        # Raise an error to enforce implementation by subclasses.
        raise NotImplementedError("Please implement _make_Iext_sequence for your custom task.")

    def update(self):
        """
        Placeholder for custom update logic.

        This method can be overridden to introduce custom behavior at each time step,
        such as adding specific types of noise or conditional stimuli.

        Raises:
            NotImplementedError: Always, until a subclass overrides it.
        """
        # Raise an error to enforce implementation by subclasses.
        raise NotImplementedError("Please implement the update logic for your custom task.")
class PopulationCoding2D(PopulationCoding):
    """
    Population coding task for 2D continuous attractor networks.

    In this task, a stimulus is presented for a specific duration, preceded and followed by
    periods of no stimulation, to test the network's ability to form and maintain a memory bump.
    """

    def __init__(
        self,
        cann_instance: BaseCANN2D,
        before_duration: time_type,
        after_duration: time_type,
        Iext: Iext_type,
        duration: time_type,
        time_step: time_type = 0.1,
    ):
        """
        Initializes the Population Coding task with `ndim` fixed to 2.

        Args:
            cann_instance (BaseCANN2D): An instance of the 2D CANN model.
            before_duration (float | Quantity): Duration of the initial period with no stimulus.
            after_duration (float | Quantity): Duration of the final period with no stimulus.
            Iext: The 2-element position of the external input during the stimulation period.
            duration (float | Quantity): The duration of the stimulation period.
            time_step (float | Quantity, optional): The simulation time step. Defaults to 0.1.
        """
        assert len(Iext) == 2, "Iext must be a tuple of two values for 2D tracking."
        # The parent constructor stores before_duration / after_duration itself,
        # so nothing else needs to be set here.
        super().__init__(
            cann_instance=cann_instance,
            ndim=2,
            before_duration=before_duration,
            after_duration=after_duration,
            Iext=Iext,
            duration=duration,
            time_step=time_step,
        )
class TemplateMatching2D(TemplateMatching):
    """
    Template matching task for 2D continuous attractor networks.

    This task presents a stimulus with added noise to test the network's ability to
    denoise the input and settle on the correct underlying pattern (template).
    """

    def __init__(
        self,
        cann_instance: BaseCANN2D,
        Iext: Iext_type,
        duration: time_type,
        time_step: time_type = 0.1,
    ):
        """
        Initializes the Template Matching task with `ndim` fixed to 2.

        Args:
            cann_instance (BaseCANN2D): An instance of the 2D CANN model.
            Iext: The 2-element position of the external input.
            duration (float | Quantity): The duration for which the noisy stimulus is presented.
            time_step (float | Quantity, optional): The simulation time step. Defaults to 0.1.
        """
        assert len(Iext) == 2, "Iext must be a tuple of two values for 2D tracking."
        super().__init__(
            cann_instance=cann_instance,
            ndim=2,
            Iext=Iext,
            duration=duration,
            time_step=time_step,
        )
class SmoothTracking2D(SmoothTracking):
    """
    Smooth tracking task for 2D continuous attractor networks.

    This task provides an external input that moves smoothly over time, testing the network's
    ability to track a continuously changing stimulus.
    """

    def __init__(
        self,
        cann_instance: BaseCANN2D,
        Iext: Sequence[Iext_type],
        duration: Sequence[time_type],
        time_step: time_type = 0.1,
    ):
        """
        Initializes the Smooth Tracking task with `ndim` fixed to 2.

        Args:
            cann_instance (BaseCANN2D): An instance of the 2D CANN model.
            Iext (Sequence): A sequence of keypoint positions for the input. Each keypoint
                is indexed per dimension by the parent class, so it should hold two
                coordinates; the parent also requires one more keypoint than durations.
            duration (Sequence[float | Quantity]): The duration of each segment of smooth movement.
            time_step (float | Quantity, optional): The simulation time step. Defaults to 0.1.
        """
        super().__init__(
            cann_instance=cann_instance,
            ndim=2,
            Iext=Iext,
            duration=duration,
            time_step=time_step,
        )
class CustomTracking2D(TrackingTask):
    """
    A template class for creating custom 2D tracking tasks.

    Users should inherit from this class and implement their own logic for
    `_make_Iext_sequence` and/or `update` to define a new task. Both methods
    raise `NotImplementedError` here.
    """

    def __init__(self, *args, **kwargs):
        """Initializes the custom task using the base class constructor with ndim=2.

        Args:
            *args: Forwarded to `TrackingTask.__init__` after `ndim` (i.e. the
                first positional argument is `config`).
            **kwargs: Forwarded to `TrackingTask.__init__`.
        """
        # `ndim` is the first positional parameter of TrackingTask.__init__. Passing
        # it as a keyword after *args made any positional argument collide with it
        # ("multiple values for argument 'ndim'"), so it is supplied positionally.
        super().__init__(2, *args, **kwargs)

    def _make_Iext_sequence(self):
        """
        Placeholder for custom input sequence generation.

        This method should be overridden to create a specific time-series of inputs.

        Raises:
            NotImplementedError: Always, until a subclass overrides it.
        """
        # Raise an error to enforce implementation by subclasses.
        raise NotImplementedError("Please implement _make_Iext_sequence for your custom task.")

    def update(self):
        """
        Placeholder for custom update logic.

        This method can be overridden to introduce custom behavior at each time step,
        such as adding specific types of noise or conditional stimuli.

        Raises:
            NotImplementedError: Always, until a subclass overrides it.
        """
        # Raise an error to enforce implementation by subclasses.
        raise NotImplementedError("Please implement the update logic for your custom task.")