# Source code for axon_sdk.networks.functional.signflip
from axon_sdk.primitives import SpikingNetworkModule, DataEncoder
from typing import Optional
# [docs]
class SignFlipperNetwork(SpikingNetworkModule):
    """Spiking network that cross-wires signed inputs to opposite-sign outputs.

    Four neurons are created: ``inp_plus``, ``inp_minus``, ``outp_plus`` and
    ``outp_minus``. ``inp_plus`` is connected to ``outp_minus`` and
    ``inp_minus`` is connected to ``outp_plus``, each through one "V" synapse.
    """

    def __init__(self, encoder: DataEncoder, module_name: Optional[str] = None):
        """Create the four neurons and the two crossed synapses.

        Args:
            encoder: Encoder stored on the instance as ``self.encoder``; it is
                not otherwise read while the network is being built.
            module_name: Optional name forwarded to the parent constructor.
        """
        super().__init__(module_name)
        self.encoder = encoder

        # Parameters shared by every neuron and synapse in this module.
        Vt = 10.0
        tm = 100.0
        tf = 20.0
        Tsyn = 1.0
        # Synaptic weight equals the threshold; presumably so that a single
        # presynaptic spike makes the target fire - confirm against the
        # neuron model.
        we = Vt

        neuron_params = {"Vt": Vt, "tf": tf, "tm": tm}

        # Creation order is kept as inp_plus, inp_minus, outp_plus, outp_minus.
        self.inp_plus = self.add_neuron(**neuron_params, neuron_name='inp_plus')
        self.inp_minus = self.add_neuron(**neuron_params, neuron_name='inp_minus')
        self.outp_plus = self.add_neuron(**neuron_params, neuron_name='outp_plus')
        self.outp_minus = self.add_neuron(**neuron_params, neuron_name='outp_minus')

        # Crossed wiring: each input drives the output of the opposite sign.
        self.connect_neurons(self.inp_plus, self.outp_minus, "V", we, Tsyn)
        self.connect_neurons(self.inp_minus, self.outp_plus, "V", we, Tsyn)
if __name__ == "__main__":
    # Demo: feed one signed value into the network, simulate, and decode the
    # value from whichever output neuron produced exactly two spikes.
    from axon_sdk.simulator import Simulator

    enc = DataEncoder()
    net = SignFlipperNetwork(encoder=enc, module_name='sign_flip_net')
    sim = Simulator(net, enc, dt=0.001)

    value = +0.5
    print(f"Input value: {value}")

    # The magnitude is applied to the input neuron matching the value's sign.
    if value >= 0:
        sim.apply_input_value(abs(value), net.inp_plus)
    else:
        sim.apply_input_value(abs(value), net.inp_minus)
    sim.simulate(300)

    spikes_plus = sim.spike_log.get(net.outp_plus.uid, [])
    spikes_minus = sim.spike_log.get(net.outp_minus.uid, [])

    # (label used in the message, sign prefix for the decoded value, spikes)
    outputs = (
        ("plus", "", spikes_plus),
        ("minus", "-", spikes_minus),
    )

    decoded_any = False
    for label, sign, spikes in outputs:
        if len(spikes) != 2:
            continue
        print(f"Got {len(spikes)} spikes in output {label}")
        # The value is decoded from the interval between the two spikes.
        interval = spikes[1] - spikes[0]
        print(f"Decoded value: {sign}{enc.decode_interval(interval)}")
        decoded_any = True

    # Report the failure explicitly instead of finishing silently.
    if not decoded_any:
        print("Didn't get 2 spikes in either output!!")