spark.nn#

Submodules#

Classes#

Module

Base class for Spark Modules.

Config

Default class for module configuration.

BaseConfig

Base class for module configuration.

Brain

Brain model.

BrainConfig

Configuration class for Brain modules.

Package Contents#

class spark.nn.Module(*, config=None, name=None, **kwargs)[source]#

Bases: flax.nnx.Module, abc.ABC, Generic[ConfigT, InputT]

Base class for Spark Modules.

Parameters:
  • config (ConfigT | None)

  • name (str | None)

name: str = 'name'[source]#
config: ConfigT[source]#
default_config: type[ConfigT][source]#
classmethod __init_subclass__(**kwargs)[source]#
__built__: bool = False[source]#
__allow_cycles__: bool = False[source]#
classmethod get_config_spec()[source]#

Returns the default configuration class associated with this module.

Return type:

type[spark.core.config.BaseSparkConfig]
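The `default_config` attribute and `get_config_spec()` suggest that each module subclass is bound to its configuration class at definition time. A minimal stand-in sketch of that pattern (not Spark's actual implementation; `LIFNeuron` and `LIFConfig` are illustrative names):

```python
from abc import ABC

class BaseConfig:
    """Stand-in for spark.core.config.BaseSparkConfig."""
    pass

class Module(ABC):
    # Each subclass may override this binding in its class body.
    default_config: type = BaseConfig

    @classmethod
    def get_config_spec(cls) -> type:
        # Returns the configuration class bound to this module.
        return cls.default_config

class LIFConfig(BaseConfig):
    pass

class LIFNeuron(Module):
    default_config = LIFConfig

spec = LIFNeuron.get_config_spec()
```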

build(input_specs)[source]#

Build method.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

Return type:

None

reset()[source]#

Reset module to its default state.

set_contract_output_specs(contract_specs)[source]#

Recurrent shape policy that pre-defines the expected shapes for the output specs.

This function is a binding contract that allows modules to accept self-connections.

Input:

contract_specs: dict[str, PortSpecs], a dictionary with a contract for the output specs.

NOTE: If both shape and output_specs are provided, output_specs takes precedence over shape.

Parameters:

contract_specs (dict[str, spark.core.specs.PortSpecs])

Return type:

None
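The motivation for the contract is that a module with a self-connection cannot infer its output specs from inputs that themselves depend on those outputs, so the specs must be declared up front. A minimal stand-in sketch of the set/get pair (shapes stand in for `PortSpecs`; `"spikes"` is an illustrative port name):

```python
# Toy sketch of the output-spec contract: declaring the specs before
# the module is ever called lets a self-connection be wired safely.
class RecurrentModule:
    def __init__(self):
        self._contract = None

    def set_contract_output_specs(self, contract_specs: dict) -> None:
        # Pre-declare the output specs this module promises to produce.
        self._contract = dict(contract_specs)

    def get_contract_output_specs(self):
        # None means no recurrent spec policy was set.
        return self._contract

mod = RecurrentModule()
mod.set_contract_output_specs({"spikes": (128,)})
```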

get_contract_output_specs()[source]#

Retrieve the recurrent spec policy of the module.

Return type:

dict[str, spark.core.specs.PortSpecs] | None

get_input_specs()[source]#

Returns a dictionary of the SparkModule’s input port specifications.

Return type:

dict[str, spark.core.specs.PortSpecs]

get_output_specs()[source]#

Returns a dictionary of the SparkModule’s output port specifications.

Return type:

dict[str, spark.core.specs.PortSpecs]

get_rng_keys(num_keys)[source]#

Generates a new collection of random keys for JAX’s random engine.

Parameters:

num_keys (int)

Return type:

jax.Array | list[jax.Array]
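The return type `jax.Array | list[jax.Array]` suggests that a single request yields a bare key while multiple requests yield a list. A stand-in sketch of that convention, using a simple counter in place of JAX PRNG key splitting (which behaves analogously):

```python
# Toy stand-in for a module-held RNG state. Spark's version presumably
# splits JAX PRNG keys; the single-key vs list-of-keys return shape is
# what this sketch illustrates.
class KeySource:
    def __init__(self, seed: int):
        self._counter = seed

    def get_rng_keys(self, num_keys: int):
        keys = []
        for _ in range(num_keys):
            self._counter += 1
            keys.append(self._counter)
        # Mirror the documented return type: one bare key, or a list.
        return keys[0] if num_keys == 1 else keys

src = KeySource(seed=0)
```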

abstractmethod __call__(**kwargs)[source]#

Execution method.

Parameters:

kwargs (InputT)

Return type:

ModuleOutput

__repr__()[source]#
inspect()[source]#

Returns a formatted string of the data structure.

Return type:

str

checkpoint(path, overwrite=False)[source]#
Return type:

None

classmethod from_checkpoint(path, safe=True)[source]#
Return type:

SparkModule

class spark.nn.Config(__skip_validation__=False, **kwargs)[source]#

Bases: BaseSparkConfig

Default class for module configuration.

Parameters:

__skip_validation__ (bool)

seed: int[source]#
dtype: jax.typing.DTypeLike[source]#
dt: float[source]#
class spark.nn.BaseConfig(__skip_validation__=False, **kwargs)[source]#

Bases: abc.ABC

Base class for module configuration.

Parameters:

__skip_validation__ (bool)

__config_delimiter__: str = '__'[source]#
__shared_config_delimiter__: str = '_s_'[source]#
__metadata__: dict[source]#
__graph_editor_metadata__: dict[source]#
classmethod __init_subclass__(**kwargs)[source]#
__eq__(other)[source]#
Return type:

bool

merge(partial={}, __skip_validation__=False)[source]#

Update config with partial overrides.

Parameters:
Return type:

None

diff(other)[source]#

Return differences from another config.

Parameters:

other (BaseSparkConfig)

Return type:

dict[str, Any]
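`merge(partial)` combined with the `__config_delimiter__ = '__'` attribute documented below suggests that nested fields can be overridden through flat, delimiter-joined keys. A stand-in sketch of that semantics on plain dictionaries (the real methods operate on config objects; field names here are illustrative):

```python
DELIM = "__"  # mirrors BaseConfig.__config_delimiter__

def merge(config: dict, partial: dict) -> None:
    # Apply flat, delimiter-joined overrides to a nested config dict,
    # e.g. {"neuron__tau": 5.0} updates config["neuron"]["tau"].
    for flat_key, value in partial.items():
        *path, leaf = flat_key.split(DELIM)
        node = config
        for part in path:
            node = node[part]
        node[leaf] = value

def diff(a: dict, b: dict) -> dict:
    # Top-level fields of `a` that differ from their value in `b`.
    return {k: v for k, v in a.items() if b.get(k) != v}

cfg = {"seed": 0, "neuron": {"tau": 10.0}}
merge(cfg, {"neuron__tau": 5.0, "seed": 1})
```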

validate(is_partial=False, errors=None, current_path=['main'])[source]#

Validates all fields in the configuration class.

Parameters:
Return type:

None

get_field_errors(field_name)[source]#

Validates all fields in the configuration class.

Parameters:

field_name (str)

Return type:

list[str]

get_metadata()[source]#

Returns all the metadata in the configuration class, indexed by the attribute name.

Return type:

dict[str, Any]

property class_ref: type[source]#

Returns the type of the associated Module/Initializer.

NOTE: It is recommended to set the __class_ref__ to the name of the associated module/initializer when defining custom configuration classes. The automatic class_ref solver is extremely brittle and likely to fail in many different custom scenarios.

Return type:

type

__post_init__()[source]#
to_dict(is_partial=False)[source]#

Serialize config to dictionary.

Parameters:

is_partial (bool)

Return type:

dict[str, dict[str, Any]]

get_kwargs()[source]#

Returns a dictionary with pairs of key, value fields (skips metadata).

Return type:

dict[str, dict[str, Any]]

classmethod from_dict(dct)[source]#

Create config instance from dictionary.

Parameters:

dct (dict)

Return type:

BaseSparkConfig
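`to_dict` and `from_dict` together imply a serialization round trip. A minimal stand-in on a plain dataclass (field names are illustrative, and the real return shape is the nested `dict[str, dict[str, Any]]` documented above rather than this flat form):

```python
from dataclasses import dataclass, asdict, fields

# Hypothetical sketch of the to_dict / from_dict round trip on a
# flat toy config; not Spark's actual serialization.
@dataclass
class ToyConfig:
    seed: int = 0
    dt: float = 1.0

    def to_dict(self) -> dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, dct: dict) -> "ToyConfig":
        # Ignore unknown keys so partial/extended dicts still load.
        names = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in dct.items() if k in names})

cfg = ToyConfig(seed=42)
restored = ToyConfig.from_dict(cfg.to_dict())
```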

to_file(file_path, is_partial=False)[source]#

Export a config instance to a .scfg file.

Parameters:
  • file_path (str)

  • is_partial (bool)

Return type:

None

classmethod from_file(file_path, is_partial=False)[source]#

Create config instance from a .scfg file.

Parameters:
  • file_path (str)

  • is_partial (bool)

Return type:

BaseSparkConfig
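`to_file` and `from_file` form the on-disk counterpart of the dictionary round trip. The `.scfg` format itself is not specified here, so this sketch uses JSON purely as a stand-in to illustrate the save/load symmetry:

```python
import json
import os
import tempfile

# Stand-in file round trip; the real .scfg serialization is unknown
# and JSON is used only for illustration.
def to_file(config: dict, file_path: str) -> None:
    with open(file_path, "w") as f:
        json.dump(config, f)

def from_file(file_path: str) -> dict:
    with open(file_path) as f:
        return json.load(f)

path = os.path.join(tempfile.mkdtemp(), "config.scfg")
to_file({"seed": 7, "dt": 0.5}, path)
restored = from_file(path)
```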

__iter__()[source]#

Custom iterator to simplify SparkConfig inspection across the entire ecosystem. This iterator excludes private fields.

Output:

field_name: str, the field name
field_value: tp.Any, the field value

Return type:

Iterator[tuple[str, dataclasses.Field, Any]]
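Matching the documented behavior of excluding private fields, here is a stand-in sketch of an iterator yielding `(name, field, value)` triples on a toy dataclass config (field names are illustrative):

```python
from dataclasses import dataclass, fields

# Sketch of __iter__ skipping private (underscore-prefixed) fields,
# mirroring the documented Iterator[tuple[str, Field, Any]] shape.
@dataclass
class IterableConfig:
    seed: int = 3
    dt: float = 1.0
    _internal: str = "hidden"

    def __iter__(self):
        for f in fields(self):
            if not f.name.startswith("_"):
                yield f.name, f, getattr(self, f.name)

visible = [name for name, _, _ in IterableConfig()]
```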

__repr__()[source]#
inspect(simplified=False)[source]#

Returns a formatted string of the data structure.

Return type:

str

with_new_seeds(seed=None)[source]#

Utility method to recompute all seed variables within the SparkConfig. Useful when creating several populations from the same config.

Return type:

BaseSparkConfig
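A stand-in sketch of the reseeding utility: copy the config, replacing the seed field with a fresh value so that several populations built from the same template do not share randomness (field names are illustrative, and the real method presumably recurses over nested configs):

```python
import random
from dataclasses import dataclass, replace

# Toy version of with_new_seeds: return a copy with a regenerated
# seed, leaving all other fields untouched.
@dataclass(frozen=True)
class SeededConfig:
    seed: int
    dt: float

    def with_new_seeds(self, seed=None) -> "SeededConfig":
        rng = random.Random(seed)
        return replace(self, seed=rng.randrange(2**31))

base = SeededConfig(seed=0, dt=1.0)
fresh = base.with_new_seeds(seed=123)
```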

class spark.nn.Brain(config=None, **kwargs)[source]#

Bases: spark.core.module.SparkModule

Brain model.

A brain is a pipeline object used to represent and coordinate a collection of neurons and interfaces. This implementation relies on a cache system to simplify parallel computation: at every timestep, all the modules in the Brain read from the cache, update their internal state, and write their outputs back to the cache. Note that this introduces a small latency between elements of the brain, which is negligible in most cases; for this reason it is recommended that only full neuron models and interfaces be used within a Brain.

Parameters:

config (BrainConfig)
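The one-step latency the description mentions can be illustrated with a toy cache loop (not Spark's implementation): every module reads the *previous* cache state, so a value propagates through each stage of the pipeline with a delay of one timestep.

```python
# Toy illustration of the Brain's cache system: modules read the old
# cache and write fresh values, so stage "b" sees stage "a"'s output
# from the previous timestep.
def step(cache: dict, modules: dict) -> dict:
    # Build the new cache entirely from the old one.
    return {name: fn(cache) for name, fn in modules.items()}

modules = {
    "a": lambda c: c["input"],  # forwards the external input
    "b": lambda c: c["a"],      # reads a's *previous* output
}
cache = {"input": 0, "a": 0, "b": 0}
trace = []
for x in [1, 2, 3]:
    cache["input"] = x
    cache = {"input": x, **step(cache, modules)}
    trace.append(cache["b"])
```

`trace` lags the input stream by one step per intermediate module, which is the latency the docstring describes.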

config: BrainConfig[source]#
resolve_initialization_order()[source]#

Resolves the initialization order of the modules.

build(input_specs)[source]#

Build method.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

reset()[source]#

Resets all the modules to their initial states.

__call__(**inputs)[source]#

Updates the brain’s state.

Parameters:

inputs (spark.core.payloads.SparkPayload)

Return type:

dict[str, spark.core.payloads.SparkPayload]

get_spikes_from_cache()[source]#

Collect the brain’s spikes.

Return type:

dict

class spark.nn.BrainConfig(__skip_validation__=False, **kwargs)[source]#

Bases: spark.core.config.BaseSparkConfig

Configuration class for Brain modules.

Parameters:

__skip_validation__ (bool)

input_map: dict[str, spark.core.specs.PortSpecs][source]#
output_map: dict[str, dict][source]#
modules_map: dict[str, spark.core.specs.ModuleSpecs][source]#
validate(is_partial=False, errors=None, current_path=['brain'])[source]#

Validates all fields in the configuration class.

Parameters:
Return type:

dict[str] | None

refresh_seeds()[source]#

Utility method to recompute all seed variables within the SparkConfig. Useful when creating several populations from the same config.