"""Tracking tasks for CANN models (source module ``src.canns.task.tracking``)."""

from collections.abc import Sequence

import brainunit as u
import numpy as np
from tqdm import tqdm

from ..models.basic.cann import BaseCANN, BaseCANN1D, BaseCANN2D
from ..typing import Iext_type, time_type
from ._base import Task

# Public task classes exported by this module, grouped by network dimensionality.
__all__ = [
    # 1D tasks
    "PopulationCoding1D",
    "TemplateMatching1D",
    "SmoothTracking1D",
    # 2D tasks
    "PopulationCoding2D",
    "TemplateMatching2D",
    "SmoothTracking2D",
]


class TrackingTask(Task):
    """
    A task for simulating the tracking of external inputs in an n-D CANN.

    This class generates a complete time-series of external input stimuli based on
    a predefined sequence of input positions and their corresponding durations.
    It is designed to provide a consistent and repeatable input protocol for
    testing and analyzing Continuous Attractor Neural Network (CANN) models.

    Nothing is generated at construction time. Calling `get_data` builds
    ``Iext_sequence``, a NumPy array of shape ``(total_steps, ndim)`` holding the
    input position at each time step, and ``data``, a NumPy array of shape
    ``(total_steps, *shape)`` holding the stimulus pattern for each step.
    """

    def __init__(
        self,
        ndim: int,
        config: dict = None,
        **kwargs,
    ):
        """Initializes the tracking task and validates its configuration.

        The input sequence itself is not built here; `get_data` generates it by
        calling `_make_Iext_sequence`.

        Args:
            ndim (int): The dimensionality of the continuous attractor network space.
            config (dict): A dictionary containing the parameters for the
                tracking simulation (required despite the ``None`` default).
                Expected keys are:
                - Iext (Sequence[float]): A sequence of positions for the
                  external input stimulus.
                - duration (Sequence[float]): A sequence of durations, where each
                  duration corresponds to an input position in `Iext`.
                - time_step (float, optional): The simulation time step.
                  Defaults to 0.1.
                - cann_instance (BaseCANN): An instance of the CANN model
                  to be used for generating the stimulus patterns.
            **kwargs: Accepted for signature compatibility; ignored.

        Raises:
            ValueError: If `config` is missing, if `Iext` or `duration` is not a
                sequence, or if `cann_instance` is not a `BaseCANN` instance.
        """
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if config is None:
            raise ValueError("TrackingTask requires a 'config' dictionary.")

        self.duration = config.get("duration", [])
        self.Iext = config.get("Iext", [])
        self.ndim = ndim

        # Validate before the values are used: summing an invalid `duration` below
        # would otherwise fail first with an unrelated error.
        if self.Iext is None or not isinstance(self.Iext, Sequence):
            raise ValueError("Configuration must include 'Iext' as a sequence of input positions.")
        if self.duration is None or not isinstance(self.duration, Sequence):
            raise ValueError("Configuration must include 'duration' as a sequence of time values.")

        # Simulation time control
        self.current_step = 0
        self.time_step = config.get("time_step", 0.1)
        self.total_duration = np.sum(self.duration)
        self.total_steps = np.ceil(self.total_duration / self.time_step).astype(dtype=int)

        self.run_steps = u.math.arange(0, self.total_duration, self.time_step)

        # cann_instance
        cann_instance = config.get("cann_instance", None)
        if cann_instance is None or not isinstance(cann_instance, BaseCANN):
            raise ValueError(
                "Configuration must include 'cann_instance' as an instance of BaseCANN."
            )
        self.shape = cann_instance.shape
        self.get_stimulus_by_pos = cann_instance.get_stimulus_by_pos

    @staticmethod
    def _ratio_to_steps(ratio: float) -> int:
        """Convert a time expressed in units of ``time_step`` into a whole step count.

        Truncates like ``int(ratio)``. The exception is a ratio within floating-point
        round-off of an integer, which snaps to that integer. For example,
        ``0.3 / 0.1`` evaluates to ``2.9999999999999996`` and would otherwise
        truncate to 2 instead of 3.
        """
        nearest = round(ratio)
        if np.isclose(ratio, nearest, rtol=1e-9, atol=1e-9):
            return int(nearest)
        return int(ratio)

    def _make_Iext_sequence(self):
        """
        Creates a time-series array of external input positions.

        Builds a step function in which ``Iext[i]`` is held constant for
        ``duration[i]``. Segment boundaries come from the cumulative elapsed time
        rather than from truncating each duration on its own, so rounding error does
        not accumulate from one segment to the next. Any steps left after the last
        boundary are filled with ``Iext[-1]``; this happens when the total duration is
        not a whole multiple of ``time_step``.

        Returns:
            np.ndarray: Array of shape ``(total_steps, ndim)`` with the external input
            position at each time step.
        """
        Iext_sequence = np.zeros((self.total_steps, self.ndim), dtype=float)

        start_step = 0
        elapsed = 0.0  # elapsed time measured in units of `time_step`
        for iext_val, dur in zip(self.Iext, self.duration, strict=False):
            elapsed += float(dur / self.time_step)
            end_step = min(self._ratio_to_steps(elapsed), self.total_steps)
            Iext_sequence[start_step:end_step, :] = iext_val
            start_step = max(start_step, end_step)
        # If total duration is not perfectly divisible, fill the remainder with the last value.
        if start_step < self.total_steps:
            Iext_sequence[start_step:] = self.Iext[-1]
        return Iext_sequence

    def get_data(self, progress_bar: bool = True):
        """
        Generates the task data by creating a sequence of external inputs
        based on the provided `Iext` and `duration` parameters.

        Sets ``self.Iext_sequence`` (see `_make_Iext_sequence`) and ``self.data``,
        whose i-th entry is ``get_stimulus_by_pos(Iext_sequence[i])``.

        Args:
            progress_bar (bool): Whether to display a tqdm progress bar.
        """
        self.Iext_sequence = self._make_Iext_sequence()

        shape = (len(self.Iext_sequence), *self.shape)
        data = np.zeros(shape, dtype=float)

        for i, pos in tqdm(
            enumerate(self.Iext_sequence),
            total=len(self.Iext_sequence),  # enumerate() has no len(); give tqdm the total
            desc=f"<{type(self).__name__}> Generating Task data",
            disable=not progress_bar,
        ):
            data[i] = self.get_stimulus_by_pos(pos)

        self.data = data

    def show_data(
        self,
        show=True,
        save_path=None,
    ):
        """Visualize the task data; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, for the base class.
        """
        raise NotImplementedError(
            "The show_data method is not implemented for TrackingTask. "
            "Please implement this method in subclasses to visualize the task data."
        )


class PopulationCoding(TrackingTask):
    """
    Population coding task for n-D continuous attractor networks.

    In this task, a stimulus is presented for a specific duration, preceded and followed by
    periods of no stimulation, to test the network's ability to form and maintain a memory bump.
    """

    def __init__(
        self,
        cann_instance: BaseCANN,
        ndim: int,
        before_duration: time_type,
        after_duration: time_type,
        Iext: Iext_type,
        duration: time_type,
        time_step: time_type = 0.1,
    ):
        """
        Initializes the Population Coding task.

        Args:
            cann_instance (BaseCANN): An instance of the CANN model.
            ndim (int): The dimensionality of the continuous attractor network.
            before_duration (float | Quantity): Duration of the initial period with no stimulus.
            after_duration (float | Quantity): Duration of the final period with no stimulus.
            Iext (float | Quantity): The position of the external input during the stimulation period.
            duration (float | Quantity): The duration of the stimulation period.
            time_step (float | Quantity, optional): The simulation time step. Defaults to 0.1.
        """
        # The task is structured as: no input -> input -> no input.
        # The base class works with per-phase sequences, so the same position is
        # repeated for all three phases. `get_data` leaves the stimulus at zero
        # outside the middle phase.
        super().__init__(
            ndim=ndim,
            config={
                "cann_instance": cann_instance,
                "Iext": (Iext, Iext, Iext),  # Repeated for before, during, and after phases.
                "duration": (before_duration, duration, after_duration),  # Duration for each phase.
                "time_step": time_step,  # Time step for the simulation.
            },
        )
        self.before_duration = before_duration
        self.after_duration = after_duration

    def get_data(self, progress_bar: bool = True):
        """
        Generates the task data: zeros, then the stimulus, then zeros.

        Sets ``self.Iext_sequence`` and ``self.data``, which has shape
        ``(total_steps, *shape)``. The stimulus pattern is computed once and
        broadcast over the stimulation window, so no per-step loop is needed.

        Args:
            progress_bar (bool): If True, print a one-line status message.
        """
        self.Iext_sequence = self._make_Iext_sequence()

        shape = (self.total_steps,) + self.shape
        data = np.zeros(shape, dtype=float)

        # Determine the time boundaries for applying the stimulus.
        start_time_step = int(self.before_duration / self.time_step)
        end_time_step = int((self.total_duration - self.after_duration) / self.time_step)

        if progress_bar:
            print(f"<{type(self).__name__}>Generating Task data(No For Loop)")

        # Skip an empty stimulation window (e.g. duration == 0). Indexing
        # Iext_sequence at start_time_step could then be out of range, and there
        # is nothing to fill anyway.
        if start_time_step < end_time_step:
            stimulus = self.get_stimulus_by_pos(self.Iext_sequence[start_time_step])
            data[start_time_step:end_time_step] = stimulus

        self.data = data


class TemplateMatching(TrackingTask):
    """
    Template matching task for n-D continuous attractor networks.

    This task presents a stimulus with added noise to test the network's ability to
    denoise the input and settle on the correct underlying pattern (template).
    """

    def __init__(
        self,
        cann_instance: BaseCANN,
        ndim: int,
        Iext: Iext_type,
        duration: time_type,
        time_step: time_type = 0.1,
        noise_level: float = 0.1,
    ):
        """
        Initializes the Template Matching task.

        Args:
            cann_instance (BaseCANN): An instance of the CANN model.
            ndim (int): The dimensionality of the continuous attractor network.
            Iext (float | Quantity): The position of the external input.
            duration (float | Quantity): The duration for which the noisy stimulus is presented.
            time_step (float | Quantity, optional): The simulation time step. Defaults to 0.1.
            noise_level (float, optional): Standard deviation of the additive Gaussian
                noise, as a fraction of ``cann_instance.A``. Defaults to 0.1.
        """
        super().__init__(
            ndim=ndim,
            config={
                "cann_instance": cann_instance,
                "Iext": (Iext,),  # Single input position for the template matching task.
                "duration": (duration,),  # Single duration for the stimulus.
                "time_step": time_step,  # Time step for the simulation.
            },
        )
        # The CANN's amplitude `A`; the noise is scaled relative to it.
        self.A = cann_instance.A
        # Noise std as a fraction of `A`. The default 0.1 is the previously hard-coded ratio.
        self.noise_level = noise_level

    def get_data(self, progress_bar: bool = True):
        """
        Generates the noisy template data.

        The clean stimulus for ``Iext`` is computed once. Each time step then adds an
        independent draw of Gaussian noise with standard deviation
        ``noise_level * A``. Sets ``self.Iext_sequence`` and ``self.data``, which has
        shape ``(total_steps, *shape)``.

        Args:
            progress_bar (bool): Whether to display a tqdm progress bar.
        """
        self.Iext_sequence = self._make_Iext_sequence()

        shape = (self.total_steps,) + self.shape
        data = np.zeros(shape, dtype=float)

        # With a zero-length task there is no row to read the position from.
        if self.total_steps > 0:
            # A single position is configured, so every row of Iext_sequence holds
            # the same value and the clean template only needs computing once.
            stimulus = self.get_stimulus_by_pos(self.Iext_sequence[0])
            noise_scale = self.noise_level * self.A

            # Add fresh noise to the stimulus for each time step.
            for i in tqdm(
                range(self.total_steps),
                desc=f"<{type(self).__name__}>Generating Task data",
                disable=not progress_bar,
            ):
                data[i] = stimulus + noise_scale * np.random.randn(*self.shape)

        self.data = data


class SmoothTracking(TrackingTask):
    """
    Smooth tracking task for n-D continuous attractor networks.

    This task provides an external input that moves smoothly over time, testing the network's
    ability to track a continuously changing stimulus.
    """

    def __init__(
        self,
        cann_instance: BaseCANN,
        ndim: int,
        Iext: Sequence[Iext_type],
        duration: Sequence[time_type],
        time_step: time_type = 0.1,
    ):
        """
        Initializes the Smooth Tracking task.

        Args:
            cann_instance (BaseCANN): An instance of the CANN model.
            ndim (int): The dimensionality of the continuous attractor network.
            Iext (Sequence[float | Quantity]): A sequence of keypoint positions for the input.
                For ``ndim > 1`` each keypoint is a sequence of ``ndim`` coordinates.
            duration (Sequence[float | Quantity]): The duration of each segment of smooth movement.
            time_step (float | Quantity, optional): The simulation time step. Defaults to 0.1.

        Raises:
            ValueError: If `Iext` does not have exactly one more element than `duration`.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(tuple(Iext)) != len(tuple(duration)) + 1:
            raise ValueError(
                "Iext must have one more element than duration to define start and end points for each segment."
            )
        super().__init__(
            ndim=ndim,
            config={
                "cann_instance": cann_instance,
                "Iext": Iext,  # Sequence of keypoint positions for the input.
                "duration": duration,  # Sequence of durations for each segment.
                "time_step": time_step,  # Time step for the simulation.
            },
        )

    def _make_Iext_sequence(self):
        """
        Creates a time-series of external input positions that smoothly transitions
        between the keypoints defined in `self.Iext`.

        Segment ``i`` linearly interpolates from ``Iext[i]`` to ``Iext[i + 1]`` over
        ``int(duration[i] / time_step)`` steps. Segments shorter than one step are
        skipped. Any remaining steps hold ``Iext[-1]``.

        Returns:
            np.ndarray: Array of shape ``(total_steps, ndim)``.
        """
        Iext_sequence = np.zeros((self.total_steps, self.ndim), dtype=float)
        start_step = 0

        for i, dur in enumerate(self.duration):
            num_steps = int(dur / self.time_step)
            if num_steps == 0:
                continue
            end_step = start_step + num_steps

            # atleast_1d lets scalar (1D) and tuple (n-D) keypoints share one code path.
            start_pos = np.atleast_1d(self.Iext[i])
            end_pos = np.atleast_1d(self.Iext[i + 1])

            # Interpolate each dimension independently.
            for d in range(self.ndim):
                Iext_sequence[start_step:end_step, d] = np.linspace(
                    start_pos[d], end_pos[d], num_steps
                )

            start_step = end_step

        # Fill any remaining steps with the final position (broadcast across dimensions).
        if start_step < self.total_steps:
            Iext_sequence[start_step:] = self.Iext[-1]

        return Iext_sequence


class PopulationCoding1D(PopulationCoding):
    """
    Population coding task for 1D continuous attractor networks.

    In this task, a stimulus is presented for a specific duration, preceded and followed by
    periods of no stimulation, to test the network's ability to form and maintain a memory bump.
    """

    def __init__(
        self,
        cann_instance: BaseCANN1D,
        before_duration: time_type,
        after_duration: time_type,
        Iext: Iext_type,
        duration: time_type,
        time_step: time_type = 0.1,
    ):
        """
        Initializes the Population Coding task.

        Args:
            cann_instance (BaseCANN1D): An instance of the 1D CANN model.
            before_duration (float | Quantity): Duration of the initial period with no stimulus.
            after_duration (float | Quantity): Duration of the final period with no stimulus.
            Iext (float | Quantity): The position of the external input during the stimulation period.
            duration (float | Quantity): The duration of the stimulation period.
            time_step (float | Quantity, optional): The simulation time step. Defaults to 0.1.
        """
        # The task is structured as: no input -> input -> no input. The parent class
        # builds the three phases and stores `before_duration` / `after_duration`.
        super().__init__(
            cann_instance=cann_instance,
            ndim=1,
            before_duration=before_duration,
            after_duration=after_duration,
            Iext=Iext,
            duration=duration,
            time_step=time_step,
        )


class TemplateMatching1D(TemplateMatching):
    """
    Template matching task for 1D continuous attractor networks.

    A noisy version of a single stimulus is presented so the network can be checked
    for its ability to denoise the input and converge on the underlying template.
    """

    def __init__(
        self,
        cann_instance: BaseCANN1D,
        Iext: Iext_type,
        duration: time_type,
        time_step: time_type = 0.1,
    ):
        """
        Build a 1D template matching task.

        Args:
            cann_instance (BaseCANN1D): The 1D CANN model that produces stimulus patterns.
            Iext (float | Quantity): Position of the (noisy) external input.
            duration (float | Quantity): How long the noisy stimulus is presented.
            time_step (float | Quantity, optional): Simulation time step. Defaults to 0.1.
        """
        super().__init__(
            ndim=1,
            cann_instance=cann_instance,
            Iext=Iext,
            duration=duration,
            time_step=time_step,
        )


class SmoothTracking1D(SmoothTracking):
    """
    Smooth tracking task for 1D continuous attractor networks.

    The external input glides between consecutive keypoints over time, which tests
    whether the network can follow a continuously moving stimulus.
    """

    def __init__(
        self,
        cann_instance: BaseCANN1D,
        Iext: Sequence[Iext_type],
        duration: Sequence[time_type],
        time_step: time_type = 0.1,
    ):
        """
        Build a 1D smooth tracking task.

        Args:
            cann_instance (BaseCANN1D): The 1D CANN model that produces stimulus patterns.
            Iext (Sequence[float | Quantity]): Keypoint positions visited by the input;
                one more entry than `duration`.
            duration (Sequence[float | Quantity]): Length of each interpolated segment.
            time_step (float | Quantity, optional): Simulation time step. Defaults to 0.1.
        """
        super().__init__(
            ndim=1,
            cann_instance=cann_instance,
            Iext=Iext,
            duration=duration,
            time_step=time_step,
        )


class CustomTracking1D(TrackingTask):
    """
    A template class for creating custom 1D tracking tasks.

    Users should inherit from this class and implement their own logic for
    `_make_Iext_sequence` and/or `update` to define a new task.
    """

    def __init__(self, *args, **kwargs):
        """Initializes the custom task using the base class constructor with ``ndim=1``.

        ``ndim`` is passed positionally first. Caller positional arguments therefore
        map to the base-class parameters that follow it (``config``): both
        ``CustomTracking1D(config)`` and ``CustomTracking1D(config=config)`` work.
        Previously the positional form collided with ``ndim`` and raised
        ``TypeError``.
        """
        super().__init__(1, *args, **kwargs)

    def _make_Iext_sequence(self):
        """
        Placeholder for custom input sequence generation.

        This method should be overridden to create a specific time-series of inputs.
        """
        # Example: raise an error to enforce implementation by subclasses.
        raise NotImplementedError("Please implement _make_Iext_sequence for your custom task.")

    def update(self):
        """
        Placeholder for custom update logic.

        This method can be overridden to introduce custom behavior at each time step,
        such as adding specific types of noise or conditional stimuli.
        """
        # Example: raise an error to enforce implementation by subclasses.
        raise NotImplementedError("Please implement the update logic for your custom task.")


class PopulationCoding2D(PopulationCoding):
    """
    Population coding task for 2D continuous attractor networks.

    In this task, a stimulus is presented for a specific duration, preceded and followed by
    periods of no stimulation, to test the network's ability to form and maintain a memory bump.
    """

    def __init__(
        self,
        cann_instance: BaseCANN2D,
        before_duration: time_type,
        after_duration: time_type,
        Iext: Iext_type,
        duration: time_type,
        time_step: time_type = 0.1,
    ):
        """
        Initializes the Population Coding task.

        Args:
            cann_instance (BaseCANN2D): An instance of the 2D CANN model.
            before_duration (float | Quantity): Duration of the initial period with no stimulus.
            after_duration (float | Quantity): Duration of the final period with no stimulus.
            Iext (Sequence[float | Quantity]): The 2D position (two values) of the external
                input during the stimulation period.
            duration (float | Quantity): The duration of the stimulation period.
            time_step (float | Quantity, optional): The simulation time step. Defaults to 0.1.

        Raises:
            ValueError: If `Iext` does not contain exactly two values.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(Iext) != 2:
            raise ValueError("Iext must be a tuple of two values for 2D tracking.")
        # The task is structured as: no input -> input -> no input. The parent class
        # builds the three phases and stores `before_duration` / `after_duration`.
        super().__init__(
            cann_instance=cann_instance,
            ndim=2,
            before_duration=before_duration,
            after_duration=after_duration,
            Iext=Iext,
            duration=duration,
            time_step=time_step,
        )


class TemplateMatching2D(TemplateMatching):
    """
    Template matching task for 2D continuous attractor networks.

    This task presents a stimulus with added noise to test the network's ability to
    denoise the input and settle on the correct underlying pattern (template).
    """

    def __init__(
        self,
        cann_instance: BaseCANN2D,
        Iext: Iext_type,
        duration: time_type,
        time_step: time_type = 0.1,
    ):
        """
        Initializes the Template Matching task.

        Args:
            cann_instance (BaseCANN2D): An instance of the 2D CANN model.
            Iext (Sequence[float | Quantity]): The 2D position (two values) of the external input.
            duration (float | Quantity): The duration for which the noisy stimulus is presented.
            time_step (float | Quantity, optional): The simulation time step. Defaults to 0.1.

        Raises:
            ValueError: If `Iext` does not contain exactly two values.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(Iext) != 2:
            raise ValueError("Iext must be a tuple of two values for 2D tracking.")
        super().__init__(
            cann_instance=cann_instance,
            ndim=2,
            Iext=Iext,
            duration=duration,
            time_step=time_step,
        )


class SmoothTracking2D(SmoothTracking):
    """
    Smooth tracking task for 2D continuous attractor networks.

    The external input glides between consecutive 2D keypoints over time, which tests
    whether the network can follow a continuously moving stimulus.
    """

    def __init__(
        self,
        cann_instance: BaseCANN2D,
        Iext: Sequence[Iext_type],
        duration: Sequence[time_type],
        time_step: time_type = 0.1,
    ):
        """
        Build a 2D smooth tracking task.

        Args:
            cann_instance (BaseCANN2D): The 2D CANN model that produces stimulus patterns.
            Iext (Sequence[float | Quantity]): Keypoint positions visited by the input;
                one more entry than `duration`.
            duration (Sequence[float | Quantity]): Length of each interpolated segment.
            time_step (float | Quantity, optional): Simulation time step. Defaults to 0.1.
        """
        super().__init__(
            ndim=2,
            cann_instance=cann_instance,
            Iext=Iext,
            duration=duration,
            time_step=time_step,
        )


class CustomTracking2D(TrackingTask):
    """
    A template class for creating custom 2D tracking tasks.

    Users should inherit from this class and implement their own logic for
    `_make_Iext_sequence` and/or `update` to define a new task.
    """

    def __init__(self, *args, **kwargs):
        """Initializes the custom task using the base class constructor with ``ndim=2``.

        ``ndim`` is passed positionally first. Caller positional arguments therefore
        map to the base-class parameters that follow it (``config``): both
        ``CustomTracking2D(config)`` and ``CustomTracking2D(config=config)`` work.
        Previously the positional form collided with ``ndim`` and raised
        ``TypeError``.
        """
        super().__init__(2, *args, **kwargs)

    def _make_Iext_sequence(self):
        """
        Placeholder for custom input sequence generation.

        This method should be overridden to create a specific time-series of inputs.
        """
        # Example: raise an error to enforce implementation by subclasses.
        raise NotImplementedError("Please implement _make_Iext_sequence for your custom task.")

    def update(self):
        """
        Placeholder for custom update logic.

        This method can be overridden to introduce custom behavior at each time step,
        such as adding specific types of noise or conditional stimuli.
        """
        # Example: raise an error to enforce implementation by subclasses.
        raise NotImplementedError("Please implement the update logic for your custom task.")