spark#

Submodules#

Attributes#

register_module

Decorator used to register a new SparkModule.

register_initializer

Decorator used to register a new Initializer.

register_payload

Decorator used to register a new SparkPayload.

register_config

Decorator used to register a new SparkConfig.

register_cfg_validator

Decorator used to register a new ConfigurationValidator.

REGISTRY

Registry singleton.

Classes#

Constant

jax.Array wrapper for constant arrays.

Variable

The base class for all Variable types.

SparkPayload

Abstract payload definition to validate exchanges between SparkModules.

SpikeArray

Representation of a collection of spike events.

CurrentArray

Representation of a collection of currents.

PotentialArray

Representation of a collection of membrane potentials.

FloatArray

Representation of a float array.

IntegerArray

Representation of an integer array.

BooleanMask

Representation of an inhibitory boolean mask.

PortSpecs

Base specification for a port of a SparkModule.

PortMap

Specification for an output port of a SparkModule.

ModuleSpecs

Specification for the SparkModule automatic constructor.

GraphEditor

Graphical editor for building Spark graphs and exporting them to Spark configuration files.

Functions#

split(node, *filters)

Wrapper around flax.nnx.split to simplify imports.

merge(graphdef, state, /, *states)

Wrapper around flax.nnx.merge to simplify imports.

Package Contents#

class spark.Constant(data, dtype=None)[source]#

jax.Array wrapper for constant arrays.

Parameters:
  • data (Any)

  • dtype (Any)

value: jax.Array[source]#
__jax_array__()[source]#
Return type:

jax.Array

__array__(dtype=None)[source]#
Return type:

numpy.ndarray

property shape: tuple[int, Ellipsis][source]#
Return type:

tuple[int, Ellipsis]

property dtype: Any[source]#
Return type:

Any

property ndim: int[source]#
Return type:

int

property size: int[source]#
Return type:

int

property T: jax.Array[source]#
Return type:

jax.Array

__neg__()[source]#
Return type:

jax.Array

__pos__()[source]#
Return type:

jax.Array

__abs__()[source]#
Return type:

jax.Array

__invert__()[source]#
Return type:

jax.Array

__add__(other)[source]#
Return type:

jax.Array

__sub__(other)[source]#
Return type:

jax.Array

__mul__(other)[source]#
Return type:

jax.Array

__truediv__(other)[source]#
Return type:

jax.Array

__floordiv__(other)[source]#
Return type:

jax.Array

__mod__(other)[source]#
Return type:

jax.Array

__matmul__(other)[source]#
Return type:

jax.Array

__pow__(other)[source]#
Return type:

jax.Array

__radd__(other)[source]#
Return type:

jax.Array

__rsub__(other)[source]#
Return type:

jax.Array

__rmul__(other)[source]#
Return type:

jax.Array

__rtruediv__(other)[source]#
Return type:

jax.Array

__rfloordiv__(other)[source]#
Return type:

jax.Array

__rmod__(other)[source]#
Return type:

jax.Array

__rmatmul__(other)[source]#
Return type:

jax.Array

__rpow__(other)[source]#
Return type:

jax.Array
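The operator overloads above all return a plain jax.Array, so a Constant behaves like an ordinary array in arithmetic. A minimal sketch of this wrapper pattern, using NumPy and a hypothetical `ConstantLike` class (not the actual spark implementation), shows how `__array__` plus delegating operators make a read-only wrapper interoperate with array code:

```python
import numpy as np

class ConstantLike:
    """Hypothetical sketch of a Constant-style wrapper: the stored array is
    read-only, and arithmetic unwraps it to a plain array."""

    def __init__(self, data, dtype=None):
        # Copy the input and mark it non-writeable so it acts like a constant.
        self._value = np.array(data, dtype=dtype)
        self._value.setflags(write=False)

    @property
    def value(self):
        return self._value

    def __array__(self, dtype=None):
        # NumPy protocol: lets np.asarray(c) and most NumPy functions
        # consume the wrapper transparently.
        return np.asarray(self._value, dtype=dtype)

    def __add__(self, other):
        # Delegation: the result is a plain array, not another wrapper.
        return self._value + other

    def __mul__(self, other):
        return self._value * other

c = ConstantLike([1.0, 2.0, 3.0])
summed = np.asarray(c).sum()   # works via __array__
shifted = c + 1.0              # works via __add__, returns a plain ndarray
```

The design choice mirrored here is that operations "escape" the wrapper: constancy protects the stored buffer, not the results derived from it.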

class spark.Variable(value, dtype=None, **metadata)[source]#

Bases: flax.nnx.Variable

The base class for all Variable types. Note that this is just a convenience wrapper around Flax’s nnx.Variable to simplify imports.

Parameters:
  • value (Any)

  • dtype (Any)

value: jax.Array[source]#
__jax_array__()[source]#
Return type:

jax.Array

__array__(dtype=None)[source]#
Return type:

numpy.ndarray

property shape: tuple[int, Ellipsis][source]#
Return type:

tuple[int, Ellipsis]

class spark.SparkPayload[source]#

Bases: abc.ABC

Abstract payload definition to validate exchanges between SparkModules.

tree_flatten()[source]#
Return type:

tuple[tuple[jax.Array], None]

classmethod tree_unflatten(aux_data, children)[source]#
Return type:

Self

property shape: Any[source]#
Return type:

Any

property dtype: Any[source]#
Return type:

Any
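The tree_flatten / tree_unflatten pair above follows JAX's pytree protocol: flatten returns the dynamic children (arrays JAX should trace) plus static aux_data, and unflatten rebuilds the instance. A minimal sketch with a hypothetical `TogglePayload` class (NumPy leaves so it runs standalone) shows the contract these methods must satisfy:

```python
import numpy as np

class TogglePayload:
    """Hypothetical payload implementing the tree_flatten/tree_unflatten
    contract that JAX's pytree machinery expects (the same shape of methods
    used with jax.tree_util.register_pytree_node_class)."""

    def __init__(self, values, flag):
        self.values = values   # dynamic leaf: would be traced by JAX transforms
        self.flag = flag       # static metadata: travels in aux_data

    def tree_flatten(self):
        # children tuple first, then static aux_data.
        return (self.values,), self.flag

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Inverse of tree_flatten: rebuild the payload from its parts.
        return cls(children[0], aux_data)

p = TogglePayload(np.arange(3.0), flag=True)
children, aux = p.tree_flatten()
q = TogglePayload.tree_unflatten(aux, children)
```

Registering such a class as a pytree node is what lets payloads flow through jit/vmap-style transforms without losing their metadata.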

class spark.SpikeArray(spikes, inhibition_mask=False, async_spikes=False)[source]#

Bases: SparkPayload

Representation of a collection of spike events.

Init:

spikes: jax.Array[bool], True if the neuron spiked, False otherwise
inhibition_mask: jax.Array[bool], True if the neuron is inhibitory, False otherwise

The async_spikes flag is automatically set to True by delay mechanisms that perform neuron-to-neuron specific delays. Note that when async_spikes is True, the shape of the spikes changes from (origin_units,) to (origin_units, target_units). This is important when implementing new synaptic models, since a fully valid synaptic model should be able to handle both cases.
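The shape contract above can be exercised with a toy synaptic weighting step. This is a sketch only (the function and variable names are hypothetical, not spark API), using NumPy broadcasting to handle both spike layouts:

```python
import numpy as np

def synaptic_current(spikes, weights):
    """Toy synaptic step that tolerates both spike layouts described above.

    weights: (origin_units, target_units)
    spikes:  (origin_units,)               -- shared spikes (async_spikes=False)
             (origin_units, target_units)  -- per-connection (async_spikes=True)
    Returns the per-target summed current, shape (target_units,).
    """
    spikes = np.asarray(spikes, dtype=float)
    if spikes.ndim == 1:
        # Broadcast the shared spike vector across the target columns.
        spikes = spikes[:, None]
    return (weights * spikes).sum(axis=0)

w = np.ones((4, 2))                        # 4 origin units, 2 target units
sync_spikes = np.array([1, 0, 1, 0])       # async_spikes=False layout
async_spikes = np.tile(sync_spikes[:, None], (1, 2))  # async_spikes=True layout
```

Because the 2-D case reduces to the 1-D case when every column is identical, a model written this way handles delayed (per-connection) and undelayed spikes uniformly.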

Parameters:
  • spikes (jax.Array)

  • inhibition_mask (jax.Array | bool)

  • async_spikes (bool)

async_spikes: bool = False[source]#
tree_flatten()[source]#
Return type:

tuple[tuple, tuple]

classmethod tree_unflatten(aux_data, children)[source]#
Return type:

Self

__jax_array__()[source]#
Return type:

jax.numpy.ndarray

__array__(dtype=None)[source]#
Return type:

numpy.ndarray

property spikes: jax.Array[source]#
Return type:

jax.Array

property inhibition_mask: jax.Array[source]#
Return type:

jax.Array

property value: jax.Array[source]#
Return type:

jax.Array

property shape: tuple[int, Ellipsis][source]#
Return type:

tuple[int, Ellipsis]

property dtype: jax.typing.DTypeLike[source]#
Return type:

jax.typing.DTypeLike

class spark.CurrentArray[source]#

Bases: ValueSparkPayload

Representation of a collection of currents.

class spark.PotentialArray[source]#

Bases: ValueSparkPayload

Representation of a collection of membrane potentials.

class spark.FloatArray[source]#

Bases: ValueSparkPayload

Representation of a float array.

class spark.IntegerArray[source]#

Bases: ValueSparkPayload

Representation of an integer array.

class spark.BooleanMask[source]#

Bases: ValueSparkPayload

Representation of an inhibitory boolean mask.

class spark.PortSpecs(payload_type, shape, dtype, description=None, async_spikes=None, inhibition_mask=None)[source]#

Base specification for a port of a SparkModule.

Parameters:
  • payload_type (type[spark.core.payloads.SparkPayload] | None)

  • shape (tuple[int, Ellipsis] | list[tuple[int, Ellipsis]] | None)

  • dtype (jax.typing.DTypeLike | None)

  • description (str | None)

  • async_spikes (bool | None)

  • inhibition_mask (bool | None)

payload_type: type[spark.core.payloads.SparkPayload] | None[source]#
shape: tuple[int, Ellipsis] | list[tuple[int, Ellipsis]] | None[source]#
dtype: jax.typing.DTypeLike | None[source]#
description: str | None = None[source]#
async_spikes: bool | None = None[source]#
inhibition_mask: bool | None = None[source]#
to_dict(is_partial=False)[source]#

Serialize PortSpecs to dictionary

Parameters:

is_partial (bool)

Return type:

dict[str, Any]

classmethod from_dict(dct, is_partial=False)[source]#

Deserialize dictionary to PortSpecs

Parameters:
  • dct (dict[str, Any])

  • is_partial (bool)

Return type:

Self

classmethod from_portspecs_list(portspec_list, validate_async=True)[source]#

Merges a list of PortSpecs into a single PortSpecs

Parameters:
  • portspec_list (list[PortSpecs])

  • validate_async (bool)

Return type:

Self

class spark.PortMap(origin, port)[source]#

Specification for an output port of a SparkModule.

Parameters:
  • origin (str)

  • port (str)

origin: str[source]#
port: str[source]#
to_dict(is_partial=False)[source]#

Serialize PortMap to dictionary

Parameters:

is_partial (bool)

Return type:

dict[str, Any]

classmethod from_dict(dct, is_partial=False)[source]#

Deserialize dictionary to PortMap

Parameters:
  • dct (dict[str, Any])

  • is_partial (bool)

Return type:

Self

class spark.ModuleSpecs(name, module_cls, inputs, config)[source]#

Specification for the SparkModule automatic constructor.

Parameters:
  • name (str)

  • module_cls (type[spark.core.module.SparkModule])

  • inputs (dict[str, list[PortMap]])

  • config (spark.core.config.BaseSparkConfig)

name: str[source]#
module_cls: type[spark.core.module.SparkModule][source]#
inputs: dict[str, list[PortMap]][source]#
config: spark.core.config.BaseSparkConfig[source]#
to_dict(is_partial=False)[source]#

Serialize ModuleSpecs to dictionary

Parameters:

is_partial (bool)

Return type:

dict[str, Any]

classmethod from_dict(dct, is_partial=False)[source]#

Deserialize dictionary to ModuleSpecs

Parameters:
  • dct (dict[str, Any])

  • is_partial (bool)

Return type:

Self

spark.split(node, *filters)[source]#

Wrapper around flax.nnx.split to simplify imports.

Parameters:
  • node (A)

  • filters (flax.nnx.filterlib.Filter)

Return type:

tuple[flax.nnx.graph.GraphDef[A], flax.nnx.graph.GraphState | flax.nnx.variablelib.VariableState, typing_extensions.Unpack[tuple[flax.nnx.graph.GraphState | flax.nnx.variablelib.VariableState, Ellipsis]]]

spark.merge(graphdef, state, /, *states)[source]#

Wrapper around flax.nnx.merge to simplify imports.

Parameters:
  • graphdef (flax.nnx.graph.GraphDef[A])

  • state (flax.nnx.graph.GraphState | flax.nnx.variablelib.VariableState)

  • states (flax.nnx.graph.GraphState | flax.nnx.variablelib.VariableState)

Return type:

A

spark.register_module[source]#

Decorator used to register a new SparkModule. Note that the module must inherit from spark.nn.Module (spark.core.module.SparkModule).

spark.register_initializer[source]#

Decorator used to register a new Initializer. Note that the initializer must inherit from spark.nn.initializers.base.Initializer.

spark.register_payload[source]#

Decorator used to register a new SparkPayload. Note that the payload must inherit from spark.SparkPayload (spark.core.payloads.SparkPayload).

spark.register_config[source]#

Decorator used to register a new SparkConfig. Note that the config must inherit from spark.nn.BaseConfig (spark.core.config.BaseSparkConfig).

spark.register_cfg_validator[source]#

Decorator used to register a new ConfigurationValidator. Note that the validator must inherit from spark.core.config_validation.ConfigurationValidator.

spark.REGISTRY[source]#

Registry singleton.
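The register_* decorators above share a common pattern: they record a class in a central registry and return it unchanged. A toy sketch of that pattern (the registry dict, decorator, and class names here are hypothetical, not the spark implementation) shows why the decorator is transparent to normal use:

```python
# Toy registry illustrating the decorator-based registration pattern
# behind register_module / register_payload / etc.
MODULE_REGISTRY: dict[str, type] = {}

def register_module(cls: type) -> type:
    # Record the class under its name, then return it unchanged so the
    # decorated class can still be subclassed and instantiated normally.
    MODULE_REGISTRY[cls.__name__] = cls
    return cls

@register_module
class LIFNeuron:
    """Hypothetical module; registration has no effect on its behavior."""
    pass
```

A singleton registry like REGISTRY is what enables features such as ModuleSpecs' automatic constructor: a serialized config can name a class, and the registry maps that name back to the type.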

class spark.GraphEditor[source]#

Graphical editor for building Spark graphs, with session save/load and export to Spark configuration files.

app[source]#
launch()[source]#

Creates and shows the editor window without blocking. This method is safe to call multiple times.

Return type:

None

exit_editor()[source]#

Exit editor.

Return type:

None

closeEvent(event)[source]#

Overrides the default close event to check for unsaved changes.

Return type:

None

new_session()[source]#

Clears the current session after checking for unsaved changes.

Return type:

None

save_session()[source]#

Saves the current session to a Spark Graph Editor file.

Return type:

bool

save_session_as()[source]#

Saves the current session to a new Spark Graph Editor file.

Return type:

bool

load_session()[source]#

Loads a graph state from a Spark Graph Editor file after checking for unsaved changes.

Return type:

None

load_from_model()[source]#

Loads a graph state from a Spark configuration file after checking for unsaved changes.

Return type:

None

export_model()[source]#

Exports the graph state to a Spark configuration file.

Return type:

bool

export_model_as()[source]#

Exports the graph state to a new Spark configuration file.

Return type:

bool