Source code for axon_sdk.visualization.chronogram
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
[docs]
def build_array(length, entry_points, fill_method="ffill"):
    """
    Build a list of the given length from sparse ``(value, index)`` entries.

    Args:
        length (int): Desired length of the output list.
        entry_points (list): Pairs ``(x, t)`` where ``x`` is the value to
            store and ``t`` is the integer index it is stored at. Entries
            whose index falls outside ``[0, length)`` are ignored. If an
            index occurs more than once, the last such entry wins.
        fill_method (str): How to fill positions that have no entry.
            ``"ffill"`` (the default) carries the most recent entry forward;
            any other value (e.g. ``"zero"``) leaves those positions at 0.

    Returns:
        list: A list of ``length`` items holding the entry values at their
        indices and filled values elsewhere. Positions before the first
        entry are always 0.
    """
    result = [0] * length
    # Track which slots were explicitly written, so that a recorded value of
    # 0 is not mistaken for "missing" during the forward fill.
    has_entry = [False] * length

    # No sort is needed: entries are written in input order, so for a
    # repeated index the last entry wins, exactly as a stable sort would give.
    for x, t in entry_points:
        if 0 <= t < length:  # silently skip out-of-range indices
            result[t] = x
            has_entry[t] = True

    if fill_method == "ffill":
        seen_any = False
        last_value = 0
        for i, present in enumerate(has_entry):
            if present:
                last_value = result[i]
                seen_any = True
            elif seen_any:
                result[i] = last_value

    return result
[docs]
def plot_chronogram(
    timesteps: list[float],
    voltage_log: dict[str, list[tuple]],
    spike_log: dict[str, list[float]],
):
    """
    Plot one voltage trace per neuron, stacked vertically, with spike markers.

    Each key of ``voltage_log`` gets its own row sharing the x axis. Spike
    times found under the same key in ``spike_log`` are drawn as dots at
    y = 0 in the row's colour. The figure is displayed with ``plt.show()``.

    Args:
        timesteps: X values shared by every trace.
        voltage_log: Maps a neuron name to its voltage record. A record with
            the same length as ``timesteps`` is plotted directly; otherwise
            it is expanded with ``build_array`` (presumably it then holds
            sparse ``(value, index)`` pairs -- confirm against the caller).
        spike_log: Maps a neuron name to the x positions of its spikes.
            Names missing from this mapping simply get no markers.

    Returns:
        None
    """
    print("Launching chronogram visualization...")
    print("=========================================")

    n = len(voltage_log)
    if n == 0:
        # plt.subplots rejects nrows=0, so there is nothing to draw.
        print("=========================================")
        return

    # squeeze=False always yields a 2-D array of Axes, so indexing works the
    # same way for a single neuron as for many.
    _, axes = plt.subplots(
        nrows=n, ncols=1, sharex=True, figsize=(10, 5), squeeze=False
    )

    # Linearly spaced values between 0 and 1; a lone neuron gets 0.0 instead
    # of dividing by zero.
    values = [i / (n - 1) for i in range(n)] if n > 1 else [0.0]
    colors = cm.rainbow(values)

    for axis, color, (name, v_log) in zip(axes[:, 0], colors, voltage_log.items()):
        if len(timesteps) != len(v_log):
            v_log = build_array(len(timesteps), v_log)
        axis.plot(timesteps, v_log, c=color)

        # Neuron names
        axis.set_ylabel(name, rotation=0, labelpad=30)

        # Voltage limits
        axis.set_ylim(-15, 15)

        # Removing extra graphical elements (axes ticks and graph spines)
        axis.get_yaxis().set_ticks([])
        axis.xaxis.set_tick_params(length=0)
        for side in ("top", "right", "bottom", "left"):
            axis.spines[side].set_visible(False)

        spikes = list(spike_log.get(name, ()))
        if spikes:
            # One scatter call per neuron. ``color=`` is used rather than
            # ``c=`` so a 4-element RGBA colour is never mistaken for four
            # per-point scalar values when there happen to be four spikes.
            axis.scatter(spikes, [0] * len(spikes), s=20, color=color)

    plt.show()
    print("=========================================")
if __name__ == "__main__":
    import numpy as np

    def build_array_with_padding(desired_length, valid_entry_points):
        """
        NumPy forward-fill demo: place each ``(x, t)`` value at index ``t``
        and carry it forward until the next entry.

        Positions before the first entry stay 0. Entries whose index is
        outside ``[0, desired_length)`` are ignored. Returns a float array
        of length ``desired_length``.
        """
        result_array = np.zeros(desired_length)

        x_values = np.array([x for x, _ in valid_entry_points], dtype=float)
        t_indices = np.array([t for _, t in valid_entry_points], dtype=int)

        # Drop entries that fall outside the output array
        in_range = (t_indices >= 0) & (t_indices < desired_length)
        x_values = x_values[in_range]
        t_indices = t_indices[in_range]

        # Write the known values and remember which slots hold a real entry
        valid_mask = np.zeros(desired_length, dtype=bool)
        result_array[t_indices] = x_values
        valid_mask[t_indices] = True

        # For every timestep, find the index of the latest entry at or before
        # it (-1 where no entry has been seen yet)
        last_entry = np.where(valid_mask, np.arange(desired_length), -1)
        last_entry = np.maximum.accumulate(last_entry)

        # Copy the value of that latest entry into each later timestep
        has_previous = last_entry >= 0
        result_array[has_previous] = result_array[last_entry[has_previous]]
        return result_array

    # Example usage:
    desired_length = 10
    valid_entry_points = [(2, 1), (3, 2), (4, 3)]
    result = build_array_with_padding(desired_length, valid_entry_points)
    print(result)  # Output: [0. 2. 3. 4. 4. 4. 4. 4. 4. 4.]