Source code for axon_sdk.primitives.networks
"""
Spiking Network Composition
===========================
This module defines the `SpikingNetworkModule` class, a hierarchical container for building composable STICK-based
spiking networks with neuron and subnetwork modularity.
Key components:
- `SpikingNetworkModule`: Base class for defining networks with neurons and submodules.
- `flatten_nested_list`: Utility to flatten arbitrarily nested lists.
"""
from .elements import Synapse, ExplicitNeuron
from typing import Optional, Self
[docs]
def flatten_nested_list(nested_list: list) -> list:
    """
    Flatten an arbitrarily nested list into a single flat list.

    Only ``list`` instances are descended into; every other element (tuples,
    strings, neurons, ...) is kept as a single item. Element order is preserved
    and the input is not modified.

    Args:
        nested_list (list): A list which may contain other lists as elements.

    Returns:
        list: A new flat list containing all non-list elements in order.
    """
    flattened = []
    for element in nested_list:
        if not isinstance(element, list):
            flattened.append(element)
            continue
        # Descend into the sublist; nesting depth is bounded by the interpreter's
        # recursion limit.
        flattened += flatten_nested_list(element)
    return flattened
[docs]
class SpikingNetworkModule:
    """
    Base class for constructing hierarchical spiking networks in the STICK model.

    Each module can contain neurons and nested subnetworks, enabling compositional
    construction of larger networks.

    Attributes:
        _neurons (list[ExplicitNeuron]): Neurons owned directly by this module.
        _subnetworks (list[SpikingNetworkModule]): Direct nested submodules.
        _uid (str): Identifier of the form ``(m<index>)`` or ``(m<index>)_<name>``;
            the index comes from a class-wide counter, so it differs per instance.
        _instance_count (int): Index of this instance, taken from the class-wide
            counter at construction time.
        _global_instance_count (int): Class-wide count of constructed modules.
    """

    # Always read and incremented through ``SpikingNetworkModule`` explicitly, so
    # every subclass shares this single counter instead of shadowing it.
    _global_instance_count = 0

    def __init__(self, module_name: Optional[str] = None) -> None:
        """
        Initialize a new spiking network module.

        Args:
            module_name (str, optional): Name appended to this module's UID.
                ``None`` or an empty string yields a UID with only the index.
        """
        self._neurons: list[ExplicitNeuron] = []
        # Submodules may be any SpikingNetworkModule, not necessarily the same
        # concrete subclass as this one (hence not ``Self``).
        self._subnetworks: list[SpikingNetworkModule] = []
        self._instance_count = SpikingNetworkModule._global_instance_count
        if module_name:
            self._uid = f"(m{self.instance_count})_{module_name}"
        else:
            self._uid = f"(m{self.instance_count})"
        SpikingNetworkModule._global_instance_count += 1

    @property
    def uid(self) -> str:
        """
        Returns:
            str: Unique identifier of this module.
        """
        return self._uid

    @property
    def neurons(self) -> list[ExplicitNeuron]:
        """
        Recursively collect all neurons from this module and its submodules.

        Neurons owned directly by this module come first, followed by those of
        each submodule in insertion order. A new list is built on every access.

        Returns:
            list[ExplicitNeuron]: List of all neurons in the hierarchy.
        """
        total_neurons = list(self._neurons)
        total_neurons.extend(
            flatten_nested_list([subnet.neurons for subnet in self._subnetworks])
        )
        return total_neurons

    @property
    def subnetworks(self) -> list["SpikingNetworkModule"]:
        """
        Returns:
            list[SpikingNetworkModule]: Direct submodules of this module. This is
            the internal list itself, not a copy.
        """
        return self._subnetworks

    @property
    def instance_count(self) -> int:
        """
        Returns:
            int: Instance index assigned at construction.
        """
        return self._instance_count

    def recurse_neurons_with_module_uid(self) -> list[dict[ExplicitNeuron, str]]:
        """
        Recursively build single-entry dictionaries mapping each neuron to the
        UID of the module that directly owns it.

        Returns:
            list[dict[ExplicitNeuron, str]]: One ``{neuron: uid}`` dictionary per
            neuron; this module's neurons first, then each submodule's in order.
        """
        total_neurons_with_module = [
            {neuron: self.uid} for neuron in self.top_module_neurons
        ]
        sub_neurons = flatten_nested_list(
            [subnet.recurse_neurons_with_module_uid() for subnet in self._subnetworks]
        )
        total_neurons_with_module.extend(sub_neurons)
        return total_neurons_with_module

    @property
    def neurons_with_module_uid(self) -> dict[ExplicitNeuron, str]:
        """
        Get a mapping from all neurons in the hierarchy to their parent module UID.

        If a neuron is reported more than once, the later entry overwrites the
        earlier one.

        Returns:
            dict[ExplicitNeuron, str]: Mapping from neuron to module UID.
        """
        combined: dict[ExplicitNeuron, str] = {}
        for entry in self.recurse_neurons_with_module_uid():
            combined.update(entry)
        return combined

    @property
    def top_module_neurons(self) -> list[ExplicitNeuron]:
        """
        Returns neurons belonging to current module, without taking submodules
        into account. This is the internal list itself, not a copy.
        """
        return self._neurons

    def add_neuron(
        self,
        Vt: float,
        tm: float,
        tf: float,
        Vreset: float = 0.0,
        neuron_name: Optional[str] = None,
    ) -> ExplicitNeuron:
        """
        Create and add a neuron to this module.

        Args:
            Vt (float): Threshold voltage.
            tm (float): Membrane time constant.
            tf (float): Synaptic decay time constant.
            Vreset (float, optional): Reset voltage after spike. Defaults to 0.0.
            neuron_name (str, optional): Optional name for this neuron.

        Returns:
            ExplicitNeuron: The newly created neuron, tagged with this module's
            instance index as ``parent_mod_id``.
        """
        new_neuron = ExplicitNeuron(
            Vt=Vt,
            tm=tm,
            tf=tf,
            Vreset=Vreset,
            neuron_name=neuron_name,
            parent_mod_id=self.instance_count,
        )
        self._neurons.append(new_neuron)
        return new_neuron

    def add_subnetwork(self, subnet: "SpikingNetworkModule") -> None:
        """
        Add a nested spiking network module.

        Args:
            subnet (SpikingNetworkModule): The submodule to add.

        Raises:
            ValueError: If ``subnet`` is this module or already contains it at
                any depth. Such a cycle would make recursive traversals (e.g.
                ``neurons``) recurse without bound.
        """
        if self._is_reachable_from(subnet):
            raise ValueError(
                f"Adding module {subnet.uid} to {self.uid} would create a cycle"
            )
        self._subnetworks.append(subnet)

    def _is_reachable_from(self, root: "SpikingNetworkModule") -> bool:
        """Return True if this module is ``root`` or nested (at any depth) under it."""
        pending = [root]
        visited: set[int] = set()
        while pending:
            module = pending.pop()
            if module is self:
                return True
            # A submodule shared by several parents is only explored once.
            if id(module) in visited:
                continue
            visited.add(id(module))
            pending.extend(module._subnetworks)
        return False

    def connect_neurons(
        self,
        pre_neuron: ExplicitNeuron,
        post_neuron: ExplicitNeuron,
        synapse_type: str,
        weight: float,
        delay: float,
    ) -> Synapse:
        """
        Connect two neurons via a synapse.

        Args:
            pre_neuron (ExplicitNeuron): Presynaptic neuron.
            post_neuron (ExplicitNeuron): Postsynaptic neuron.
            synapse_type (str): Type of synapse ('V', 'ge', 'gf', 'gate', etc.).
            weight (float): Synaptic weight.
            delay (float): Synaptic delay in seconds.

        Returns:
            Synapse: The created synapse, which is also appended to
            ``pre_neuron.out_synapses``.
        """
        synapse = Synapse(
            pre_neuron=pre_neuron,
            post_neuron=post_neuron,
            synapse_type=synapse_type,
            weight=weight,
            delay=delay,
        )
        pre_neuron.out_synapses.append(synapse)
        return synapse