spark.nn.components.synapses.traced

Classes

TracedSynapsesConfig

TracedSynapses model configuration class.

TracedSynapses

Traced synaptic model.

RDTracedSynapsesConfig

RDTracedSynapses model configuration class.

RDTracedSynapses

Rise-Decay traced synaptic model.

RFSTracedSynapsesConfig

RFSTracedSynapses model configuration class.

RFSTracedSynapses

Rise-Fast-Slow traced synaptic model.

Module Contents

class spark.nn.components.synapses.traced.TracedSynapsesConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.synapses.linear.LinearSynapsesConfig

TracedSynapses model configuration class.

Parameters:

__skip_validation__ (bool)

tau: float | jax.Array | spark.nn.initializers.base.Initializer
scale: float | jax.Array
base: float | jax.Array
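
The three fields above parameterize a single trace. As a minimal sketch, assuming (this reference does not confirm it) that the trace relaxes exponentially toward `base` with time constant `tau` and that each input increments it by `scale` times the input, one update step could look like the following. A scalar parameter is broadcast to every unit, while a sequence supplies one value per unit. The names `as_per_unit`, `trace_step`, and `dt` are hypothetical, and plain Python lists stand in for `jax.Array`.

```python
import math


def as_per_unit(value, n_units):
    """Broadcast a scalar to one value per unit; validate a per-unit sequence."""
    if isinstance(value, (int, float)):
        return [float(value)] * n_units
    values = [float(v) for v in value]
    if len(values) != n_units:
        raise ValueError("a per-unit parameter needs exactly one entry per unit")
    return values


def trace_step(x, inp, tau, scale, base, dt=1.0):
    """Advance an exponential trace by one step of length dt (assumed dynamics).

    Each unit relaxes toward base[i] with time constant tau[i]; the input
    inp[i] is then added, multiplied by scale[i].
    """
    n = len(x)
    tau, scale, base = (as_per_unit(p, n) for p in (tau, scale, base))
    return [
        b + (xi - b) * math.exp(-dt / t) + s * u
        for xi, u, t, s, b in zip(x, inp, tau, scale, base)
    ]
```

For example, `trace_step([0.0, 0.0], [1.0, 0.0], tau=10.0, scale=1.0, base=0.0)` jumps the first unit to 1.0; a following step with no input and `tau=[10.0, 20.0]` lets each unit decay with its own time constant.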

class spark.nn.components.synapses.traced.TracedSynapses(config=None, **kwargs)

Bases: spark.nn.components.synapses.linear.LinearSynapses

Traced synaptic model. Output currents are computed as the trace of the dot product of the kernel with the input spikes.
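
To make "the trace of the dot product of the kernel with the input spikes" concrete, here is a minimal, self-contained stand-in in plain Python. It assumes the kernel is stored as one row of weights per output unit, and that the trace decays exponentially toward `base` with time constant `tau` (measured in steps of `dt`) while each step's synaptic drive is added, scaled by `scale`. The class `TracedSynapsesSketch` and its methods are hypothetical and do not reproduce spark's API.

```python
import math


class TracedSynapsesSketch:
    """Hypothetical stand-in for a traced synapse layer (not spark's class).

    kernel[i][j] is assumed to be the weight from input j to output unit i.
    """

    def __init__(self, kernel, tau, scale=1.0, base=0.0, dt=1.0):
        self.kernel = kernel
        self.decay = math.exp(-dt / tau)  # fraction of the deviation from base kept per step
        self.scale = scale
        self.base = base
        self.reset()

    def reset(self):
        # Return every output trace to its resting value.
        self.trace = [self.base] * len(self.kernel)

    def __call__(self, spikes):
        # Synaptic drive: dot product of each kernel row with the spike vector.
        drive = [sum(w * s for w, s in zip(row, spikes)) for row in self.kernel]
        # Decay toward base, then add the scaled drive; the trace is the current.
        self.trace = [
            self.base + (x - self.base) * self.decay + self.scale * d
            for x, d in zip(self.trace, drive)
        ]
        return list(self.trace)
```

Calling the object once per simulation step yields that step's currents, for example `TracedSynapsesSketch([[1.0, 0.5], [0.0, 2.0]], tau=5.0)([1, 1])` gives `[1.5, 2.0]`. Here `reset()` simply returns every trace to `base`, which is one reasonable reading of "Resets component state."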

Init:

units: tuple[int, …]
kernel: KernelInitializerConfig
tau: float | jax.Array
scale: float | jax.Array
base: float | jax.Array

Input:

spikes: SpikeArray

Output:

currents: CurrentArray

Parameters:

config (TracedSynapsesConfig | None)

config: TracedSynapsesConfig
current_tracer: spark.core.tracers.Tracer
build(input_specs)

Builds the component from the specifications of its input ports.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

reset()

Resets component state.

Return type:

None

class spark.nn.components.synapses.traced.RDTracedSynapsesConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.synapses.linear.LinearSynapsesConfig

RDTracedSynapses model configuration class.

Parameters:

__skip_validation__ (bool)

tau_rise: float | jax.Array | spark.nn.initializers.base.Initializer
scale_rise: float | jax.Array
base_rise: float | jax.Array
tau_decay: float | jax.Array | spark.nn.initializers.base.Initializer
scale_decay: float | jax.Array
base_decay: float | jax.Array
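
When choosing `tau_rise` and `tau_decay`, it can help to know when the response to a single input peaks. Assuming, for illustration only, that the rise-decay trace behaves like the standard difference-of-exponentials kernel (this reference does not state the dynamics), write τ_r for `tau_rise` and τ_d for `tau_decay` and set the derivative to zero:

```latex
% Assumed impulse response of the rise-decay trace (t \ge 0, \tau_d > \tau_r)
g(t) = e^{-t/\tau_d} - e^{-t/\tau_r}

% Peak: set the derivative to zero and solve for t
g'(t) = -\frac{1}{\tau_d} e^{-t/\tau_d} + \frac{1}{\tau_r} e^{-t/\tau_r} = 0
\;\Longrightarrow\;
t_{\mathrm{peak}} = \frac{\tau_r \tau_d}{\tau_d - \tau_r} \ln\frac{\tau_d}{\tau_r}
```

For example, `tau_rise = 2` and `tau_decay = 10` (in the same time units) give t_peak = 2.5 ln 5 ≈ 4.02.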

class spark.nn.components.synapses.traced.RDTracedSynapses(config=None, **kwargs)

Bases: spark.nn.components.synapses.linear.LinearSynapses

Rise-Decay traced synaptic model. Output currents are computed as the RDTrace of the dot product of the kernel with the input spikes.
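
The rise-decay model has separate time constants for the rising and falling phases. As an illustration only, the sketch below swaps in the textbook difference-of-exponentials synapse for a single postsynaptic unit: two traces receive the same drive (the kernel-spike dot product for that unit), and the output is the slower decay trace minus the faster rise trace. Whether `RDTracer` uses this exact update is an assumption; `rise_decay_response` and `dt` are hypothetical names, and the `scale_*` and `base_*` fields are left out (taken as 1 and 0).

```python
import math


def rise_decay_response(drive, tau_rise, tau_decay, dt=1.0):
    """Difference-of-exponentials response of one unit to a drive sequence.

    Illustrative model only: a fast "rise" trace and a slow "decay" trace
    receive the same drive, and the output is decay - rise, so the response
    grows on the scale of tau_rise and falls on the scale of tau_decay.
    """
    if not 0.0 < tau_rise < tau_decay:
        raise ValueError("expected 0 < tau_rise < tau_decay")
    a_rise = math.exp(-dt / tau_rise)    # per-step retention of the rise trace
    a_decay = math.exp(-dt / tau_decay)  # per-step retention of the decay trace
    rise = decay = 0.0
    out = []
    for d in drive:
        rise = rise * a_rise + d
        decay = decay * a_decay + d
        out.append(decay - rise)
    return out
```

With `tau_rise=2.0` and `tau_decay=10.0`, a single unit of drive at step 0 gives a response of 0 at step 0 that peaks at step 4 and then falls off on the slower time scale.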

Init:

units: tuple[int, …]
kernel: KernelInitializerConfig
tau_rise: float | jax.Array
scale_rise: float | jax.Array
base_rise: float | jax.Array
tau_decay: float | jax.Array
scale_decay: float | jax.Array
base_decay: float | jax.Array

Input:

spikes: SpikeArray

Output:

currents: CurrentArray

Parameters:

config (RDTracedSynapsesConfig | None)

config: RDTracedSynapsesConfig
current_tracer: spark.core.tracers.RDTracer
build(input_specs)

Builds the component from the specifications of its input ports.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

reset()

Resets component state.

Return type:

None

class spark.nn.components.synapses.traced.RFSTracedSynapsesConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.synapses.linear.LinearSynapsesConfig

RFSTracedSynapses model configuration class.

Parameters:

__skip_validation__ (bool)

alpha: float | jax.Array
tau_rise: float | jax.Array | spark.nn.initializers.base.Initializer
scale_rise: float | jax.Array
base_rise: float | jax.Array
tau_fast_decay: float | jax.Array | spark.nn.initializers.base.Initializer
scale_fast_decay: float | jax.Array
base_fast_decay: float | jax.Array
tau_slow_decay: float | jax.Array | spark.nn.initializers.base.Initializer
scale_slow_decay: float | jax.Array
base_slow_decay: float | jax.Array

class spark.nn.components.synapses.traced.RFSTracedSynapses(config=None, **kwargs)

Bases: spark.nn.components.synapses.linear.LinearSynapses

Rise-Fast-Slow traced synaptic model. Output currents are computed as a trace of the dot product of the kernel with the input spikes, combining a rise phase with fast and slow decay components.
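
Going by the configuration fields (a rise time constant, fast and slow decay time constants, and a weight `alpha`), one plausible reading is a rise phase followed by a mixture of a fast and a slow decay. The sketch below assumes that reading for a single postsynaptic unit: three traces receive the same drive, and the output is `alpha` times the fast trace plus `1 - alpha` times the slow trace, minus the rise trace. This is a textbook multi-exponential synapse chosen for illustration, not the documented behavior of the tracer; `rise_fast_slow_response` and `dt` are hypothetical, and the `scale_*` and `base_*` fields are left out (taken as 1 and 0).

```python
import math


def rise_fast_slow_response(drive, tau_rise, tau_fast, tau_slow, alpha, dt=1.0):
    """Rise plus fast/slow decay response of one unit (illustrative model only).

    Three traces receive the same drive; the output mixes the fast and slow
    decay traces with weights alpha and 1 - alpha and subtracts the rise
    trace, so a single input starts the response at zero (up to rounding).
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    if not 0.0 < tau_rise < tau_fast <= tau_slow:
        raise ValueError("expected 0 < tau_rise < tau_fast <= tau_slow")
    # Per-step retention factors of the three traces.
    a_rise, a_fast, a_slow = (math.exp(-dt / t) for t in (tau_rise, tau_fast, tau_slow))
    rise = fast = slow = 0.0
    out = []
    for d in drive:
        rise = rise * a_rise + d
        fast = fast * a_fast + d
        slow = slow * a_slow + d
        out.append(alpha * fast + (1.0 - alpha) * slow - rise)
    return out
```

Under this model, `alpha=1.0` reduces to a plain rise-decay response with `tau_fast` as the decay constant, while smaller values of `alpha` leave a longer tail governed by `tau_slow`.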

Init:

units: tuple[int, …]
kernel: KernelInitializerConfig
alpha: float | jax.Array
tau_rise: float | jax.Array
scale_rise: float | jax.Array
base_rise: float | jax.Array
tau_fast_decay: float | jax.Array
scale_fast_decay: float | jax.Array
base_fast_decay: float | jax.Array
tau_slow_decay: float | jax.Array
scale_slow_decay: float | jax.Array
base_slow_decay: float | jax.Array

Input:

spikes: SpikeArray

Output:

currents: CurrentArray

Parameters:

config (RFSTracedSynapsesConfig | None)

config: RFSTracedSynapsesConfig
current_tracer: spark.core.tracers.RDTracer
build(input_specs)

Builds the component from the specifications of its input ports.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

reset()

Resets component state.

Return type:

None