spark.nn.components.synapses

Submodules

Classes

Synanpses

Abstract synapse model.

SynanpsesOutput

Generic synapses model output spec.

LinearSynapses

Linear synaptic model.

LinearSynapsesConfig

LinearSynapses model configuration class.

TracedSynapses

Traced synaptic model.

TracedSynapsesConfig

TracedSynapses model configuration class.

RDTracedSynapses

Rise-Decay traced synaptic model.

RDTracedSynapsesConfig

RDTracedSynapses model configuration class.

RFSTracedSynapses

Rise, fast-decay and slow-decay traced synaptic model.

RFSTracedSynapsesConfig

RFSTracedSynapses model configuration class.

Package Contents

class spark.nn.components.synapses.Synanpses(config=None, **kwargs)

Bases: spark.nn.components.base.Component, Generic[ConfigT]

Abstract synapse model.

Note that kernel entries are required to be in pA for numerical stability, since models usually run in half precision. Somas, however, expect currents in nA, so the output is rescaled (divided by 1000, since 1 nA = 1000 pA).
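
The unit handling described above can be sketched in JAX as follows. This is a minimal illustration, not the library's implementation: the dense `(n_out, n_in)` kernel layout, the float16 storage and the helper name `synaptic_currents_na` are assumptions.

```python
import jax.numpy as jnp

PA_PER_NA = 1000.0  # 1 nA = 1000 pA


def synaptic_currents_na(kernel_pa: jnp.ndarray, spikes: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: weighted spike sum computed in pA, returned in nA.

    kernel_pa: (n_out, n_in) weights in pA, stored here as float16.
    spikes:    (n_in,) binary spike vector.
    """
    # Accumulate in the kernel's half-precision dtype, where pA-scale
    # values such as 100.0 or 250.0 are represented comfortably.
    current_pa = jnp.dot(kernel_pa, spikes.astype(kernel_pa.dtype))
    # Upcast before converting to nA so the smaller values keep precision.
    return current_pa.astype(jnp.float32) / PA_PER_NA


kernel = jnp.array([[500.0, 250.0], [100.0, 0.0]], dtype=jnp.float16)
spikes = jnp.array([1, 1])
currents_na = synaptic_currents_na(kernel, spikes)  # 0.75 nA and 0.1 nA
```

Doing the division after the upcast is a design choice of this sketch; it avoids representing sub-nA results in float16.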

Init:

Input:

spikes: SpikeArray

Output:

currents: CurrentArray

Parameters:

config (ConfigT | None)

abstractmethod get_kernel()
Return type:

spark.core.payloads.FloatArray

abstractmethod set_kernel(new_kernel)
Parameters:

new_kernel (spark.core.payloads.FloatArray)

Return type:

None

__call__(spikes)

Compute the synapses' output currents from the input spikes.

Parameters:

spikes (spark.core.payloads.SpikeArray)

Return type:

SynanpsesOutput

class spark.nn.components.synapses.SynanpsesOutput

Bases: TypedDict

Generic synapses model output spec.

currents: spark.core.payloads.CurrentArray

class spark.nn.components.synapses.LinearSynapses(config=None, **kwargs)

Bases: spark.nn.components.synapses.base.Synanpses

Linear synaptic model. Output currents are computed as the dot product of the kernel with the input spikes.
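
A minimal sketch of this dot product when `units` is multi-dimensional. It assumes, without confirmation from this page, that the kernel has shape `(*units, *input_shape)`; `linear_synapses` is a hypothetical name, not part of the package.

```python
import jax.numpy as jnp


def linear_synapses(kernel: jnp.ndarray, spikes: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical linear synapse.

    kernel: (*units, *input_shape) weights (assumed layout).
    spikes: (*input_shape) spikes, boolean or numeric.
    Returns currents of shape `units`.
    """
    # An integer `axes` makes tensordot sum over the last spikes.ndim axes
    # of the kernel and all axes of the spike array.
    return jnp.tensordot(kernel, spikes.astype(kernel.dtype), axes=spikes.ndim)


kernel = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # units=(2,), input_shape=(3,)
spikes = jnp.array([True, False, True])
currents = linear_synapses(kernel, spikes)  # [4.0, 10.0]
```

Contracting over all input axes keeps the output shape equal to `units` regardless of how many dimensions the input has.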

Init:

units: tuple[int, ...]
kernel: jax.Array | Initializer

Input:

spikes: SpikeArray

Output:

currents: CurrentArray

Reference:

Neuronal Dynamics: From Single Neurons to Networks and Models of Cognition. Gerstner W, Kistler WM, Naud R, Paninski L. Chapter 1.3 Integrate-And-Fire Models https://neuronaldynamics.epfl.ch/online/Ch1.S3.html

Parameters:

config (LinearSynapsesConfig | None)

config: LinearSynapsesConfig
build(input_specs)

Build the component from the given input port specifications.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

get_kernel()
Return type:

spark.core.payloads.FloatArray

get_flat_kernel()
Return type:

spark.core.payloads.FloatArray

set_kernel(new_kernel)
Parameters:

new_kernel (spark.core.payloads.FloatArray)

Return type:

None

class spark.nn.components.synapses.LinearSynapsesConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.synapses.base.SynanpsesConfig

LinearSynapses model configuration class.

Parameters:

__skip_validation__ (bool)

units: tuple[int, ...]
kernel: jax.Array | spark.nn.initializers.base.Initializer

class spark.nn.components.synapses.TracedSynapses(config=None, **kwargs)

Bases: spark.nn.components.synapses.linear.LinearSynapses

Traced synaptic model. Output currents are computed as the trace of the dot product of the kernel with the input spikes.
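
The tracing step can be sketched as an exponential trace driven by the kernel-weighted spikes. The update rule used here (relax toward `base` with time constant `tau`, then add `scale` times the new input, with a hypothetical time step `dt`) is an assumption about what `spark.core.tracers.Tracer` does, not its documented behaviour, and `traced_synapses` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def traced_synapses(kernel, spike_train, dt=1.0, tau=5.0, scale=1.0, base=0.0):
    """Hypothetical traced synapse (assumed exponential update rule).

    kernel:      (n_out, n_in) float32 weights.
    spike_train: (T, n_in) spikes, one row per time step.
    Returns currents of shape (T, n_out).
    """
    decay = jnp.exp(-dt / tau)  # per-step decay factor; tau may be per unit

    def step(trace, spikes):
        drive = kernel @ spikes.astype(kernel.dtype)  # linear synapse input
        # Relax toward `base`, then add the scaled new input.
        trace = base + (trace - base) * decay + scale * drive
        return trace, trace

    init = jnp.full(kernel.shape[0], base, dtype=kernel.dtype)
    _, currents = jax.lax.scan(step, init, spike_train)
    return currents


kernel = jnp.array([[1.0]])
spike_train = jnp.array([[1], [0], [0]])  # one spike at t = 0
currents = traced_synapses(kernel, spike_train, tau=1.0)  # 1, e^-1, e^-2
```

Using `jax.lax.scan` keeps the trace as carried state, which mirrors why the traced variants expose a `reset()` method.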

Init:

units: tuple[int, ...]
kernel: KernelInitializerConfig
tau: float | jax.Array
scale: float | jax.Array
base: float | jax.Array

Input:

spikes: SpikeArray

Output:

currents: CurrentArray

Parameters:

config (TracedSynapsesConfig | None)

config: TracedSynapsesConfig
current_tracer: spark.core.tracers.Tracer
build(input_specs)

Build the component from the given input port specifications.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

reset()

Resets component state.

Return type:

None

class spark.nn.components.synapses.TracedSynapsesConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.synapses.linear.LinearSynapsesConfig

TracedSynapses model configuration class.

Parameters:

__skip_validation__ (bool)

tau: float | jax.Array | spark.nn.initializers.base.Initializer
scale: float | jax.Array
base: float | jax.Array
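
Because these fields accept either a float or a `jax.Array`, one value can be shared by all units or set per unit. A small sketch of the broadcasting involved, assuming an exponential per-step decay factor `exp(-dt / tau)`; `decay_factors` and `dt` are hypothetical and not part of the config.

```python
import jax.numpy as jnp


def decay_factors(tau, units, dt=1.0):
    """Hypothetical helper: per-unit decay factors exp(-dt / tau) of shape `units`.

    `tau` may be a Python float (shared by every unit) or an array that
    broadcasts against `units` (for example one value per unit).
    """
    tau = jnp.broadcast_to(jnp.asarray(tau, dtype=jnp.float32), units)
    return jnp.exp(-dt / tau)


shared = decay_factors(5.0, (3,))                            # same factor for all units
per_unit = decay_factors(jnp.array([1.0, 5.0, 20.0]), (3,))  # one factor per unit
```

Longer time constants give factors closer to 1, so those units forget their input more slowly.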

class spark.nn.components.synapses.RDTracedSynapses(config=None, **kwargs)

Bases: spark.nn.components.synapses.linear.LinearSynapses

Rise-decay traced synaptic model. Output currents are computed by applying a rise-decay trace (RDTracer) to the dot product of the kernel with the input spikes.
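
One common way to obtain a current that first rises and then decays is the difference of two exponential traces (a dual-exponential synapse). The sketch below uses that technique; whether `spark.core.tracers.RDTracer` follows the same rule, and how its `scale_*` and `base_*` parameters enter, is not stated here, so unit scales and zero bases are assumed. `rd_traced_synapses` and `dt` are hypothetical.

```python
import jax
import jax.numpy as jnp


def rd_traced_synapses(kernel, spike_train, dt=1.0, tau_rise=1.0, tau_decay=5.0):
    """Hypothetical dual-exponential synapse (expects tau_rise < tau_decay).

    kernel:      (n_out, n_in) float32 weights.
    spike_train: (T, n_in) spikes.
    Returns currents of shape (T, n_out): decay trace minus rise trace.
    """
    rise_factor = jnp.exp(-dt / tau_rise)
    decay_factor = jnp.exp(-dt / tau_decay)

    def step(carry, spikes):
        rise, decay = carry
        drive = kernel @ spikes.astype(kernel.dtype)
        # Both traces receive the same input but forget it at different rates.
        rise = rise * rise_factor + drive
        decay = decay * decay_factor + drive
        # Early on the fast trace cancels the slow one, producing the rise.
        return (rise, decay), decay - rise

    zeros = jnp.zeros(kernel.shape[0], dtype=kernel.dtype)
    _, currents = jax.lax.scan(step, (zeros, zeros), spike_train)
    return currents


kernel = jnp.array([[1.0]])
spike_train = jnp.zeros((10, 1)).at[0, 0].set(1.0)  # a single spike at t = 0
currents = rd_traced_synapses(kernel, spike_train)[:, 0]
# starts at 0, peaks at step 2, then decays
```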

Init:

units: tuple[int, ...]
kernel: KernelInitializerConfig
tau_rise: float | jax.Array
scale_rise: float | jax.Array
base_rise: float | jax.Array
tau_decay: float | jax.Array
scale_decay: float | jax.Array
base_decay: float | jax.Array

Input:

spikes: SpikeArray

Output:

currents: CurrentArray

Parameters:

config (RDTracedSynapsesConfig | None)

config: RDTracedSynapsesConfig
current_tracer: spark.core.tracers.RDTracer
build(input_specs)

Build the component from the given input port specifications.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

reset()

Resets component state.

Return type:

None

class spark.nn.components.synapses.RDTracedSynapsesConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.synapses.linear.LinearSynapsesConfig

RDTracedSynapses model configuration class.

Parameters:

__skip_validation__ (bool)

tau_rise: float | jax.Array | spark.nn.initializers.base.Initializer
scale_rise: float | jax.Array
base_rise: float | jax.Array
tau_decay: float | jax.Array | spark.nn.initializers.base.Initializer
scale_decay: float | jax.Array
base_decay: float | jax.Array

class spark.nn.components.synapses.RFSTracedSynapses(config=None, **kwargs)

Bases: spark.nn.components.synapses.linear.LinearSynapses

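Rise, fast-decay and slow-decay traced synaptic model. Output currents are computed as the trace of the dot product of the kernel with the input spikes, using a rise time constant (tau_rise) and separate fast and slow decay time constants (tau_fast_decay, tau_slow_decay).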
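
Judging only from the parameter names, this model combines a rise trace with a fast and a slow decay trace, with `alpha` presumably mixing the two decays. A minimal sketch under that reading, assuming the current is `alpha * fast + (1 - alpha) * slow - rise` with all three traces driven by the same input (unit scales, zero bases). The combination rule, `dt`, and the name `rfs_traced_synapses` are assumptions, not the library's documented behaviour.

```python
import jax
import jax.numpy as jnp


def rfs_traced_synapses(kernel, spike_train, dt=1.0, alpha=0.7,
                        tau_rise=1.0, tau_fast_decay=5.0, tau_slow_decay=50.0):
    """Hypothetical rise / fast-decay / slow-decay synapse.

    kernel:      (n_out, n_in) float32 weights.
    spike_train: (T, n_in) spikes.
    Returns currents of shape (T, n_out).
    """
    factors = jnp.exp(-dt / jnp.array([tau_rise, tau_fast_decay, tau_slow_decay]))

    def step(traces, spikes):
        drive = kernel @ spikes.astype(kernel.dtype)  # shape (n_out,)
        # traces has shape (3, n_out): rows are rise, fast decay, slow decay.
        traces = traces * factors[:, None] + drive[None, :]
        rise, fast, slow = traces
        # Assumed mixing rule: alpha weights the fast decay against the slow one.
        current = alpha * fast + (1.0 - alpha) * slow - rise
        return traces, current

    init = jnp.zeros((3, kernel.shape[0]), dtype=kernel.dtype)
    _, currents = jax.lax.scan(step, init, spike_train)
    return currents


kernel = jnp.array([[1.0]])
spike_train = jnp.zeros((20, 1)).at[0, 0].set(1.0)  # a single spike at t = 0
currents = rfs_traced_synapses(kernel, spike_train)[:, 0]
```

With `alpha = 1.0` the slow trace drops out and this sketch reduces to a plain difference of two exponentials.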

Init:

units: tuple[int, ...]
kernel: KernelInitializerConfig
alpha: float | jax.Array
tau_rise: float | jax.Array
scale_rise: float | jax.Array
base_rise: float | jax.Array
tau_fast_decay: float | jax.Array
scale_fast_decay: float | jax.Array
base_fast_decay: float | jax.Array
tau_slow_decay: float | jax.Array
scale_slow_decay: float | jax.Array
base_slow_decay: float | jax.Array

Input:

spikes: SpikeArray

Output:

currents: CurrentArray

Parameters:

config (RFSTracedSynapsesConfig | None)

config: RFSTracedSynapsesConfig
current_tracer: spark.core.tracers.RDTracer
build(input_specs)

Build the component from the given input port specifications.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

reset()

Resets component state.

Return type:

None

class spark.nn.components.synapses.RFSTracedSynapsesConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.synapses.linear.LinearSynapsesConfig

RFSTracedSynapses model configuration class.

Parameters:

__skip_validation__ (bool)

alpha: float | jax.Array
tau_rise: float | jax.Array | spark.nn.initializers.base.Initializer
scale_rise: float | jax.Array
base_rise: float | jax.Array
tau_fast_decay: float | jax.Array | spark.nn.initializers.base.Initializer
scale_fast_decay: float | jax.Array
base_fast_decay: float | jax.Array
tau_slow_decay: float | jax.Array | spark.nn.initializers.base.Initializer
scale_slow_decay: float | jax.Array
base_slow_decay: float | jax.Array