spark.nn#

Classes#

Module

Base class for Spark modules.

Config

Default class for module configuration.

BaseConfig

Base class for module configuration.

Brain

Abstract brain model.

BrainConfig

Configuration class for Brain models.

Package Contents#

class spark.nn.Module(*, config=None, name=None, **kwargs)[source]#

Bases: flax.nnx.Module, abc.ABC, Generic[ConfigT, InputT]

Base class for Spark modules.

Parameters:
  • config (ConfigT | None)

  • name (str | None)

name: str = 'name'[source]#
config: ConfigT[source]#
default_config: type[ConfigT][source]#
classmethod __init_subclass__(**kwargs)[source]#
__built__: bool = False[source]#
__allow_cycles__: bool = False[source]#
classmethod get_config_spec()[source]#

Returns the default configuration class associated with this module.

Return type:

type[spark.core.config.BaseSparkConfig]

build(input_specs)[source]#

Builds the module from the given input specifications.

Parameters:

input_specs (dict[str, spark.core.specs.InputSpec])

Return type:

None

reset()[source]#

Reset module to its default state.

set_recurrent_shape_contract(shape=None, output_shapes=None)[source]#

The recurrent shape policy pre-defines the expected shapes for the output specs.

This function is a binding contract that allows modules to accept self-connections.

Parameters:
  • shape (tuple[int, …] | None) – A common shape for all outputs.

  • output_shapes (dict[str, tuple[int, …]] | None) – A specific shape for each individual output variable.

NOTE: If both shape and output_shapes are provided, output_shapes takes precedence over shape.

Return type:

None
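
For illustration, a minimal sketch of declaring the contract before wiring a self-connection; the module instance lif and the output name "spikes" are hypothetical stand-ins:

    # `lif` is a stand-in for any spark.nn.Module instance.
    lif.set_recurrent_shape_contract(shape=(128,))

    # Or give each output variable its own shape; output_shapes takes
    # precedence over shape when both are provided.
    lif.set_recurrent_shape_contract(output_shapes={"spikes": (128,)})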

get_recurrent_shape_contract()[source]#

Retrieve the recurrent shape policy of the module.

get_input_specs()[source]#

Returns a dictionary of the SparkModule’s input port specifications.

Return type:

dict[str, spark.core.specs.InputSpec]

get_output_specs()[source]#

Returns a dictionary of the SparkModule’s output port specifications.

Return type:

dict[str, spark.core.specs.OutputSpec]

get_rng_keys(num_keys)[source]#

Generates a new collection of random keys for JAX’s random engine.

Parameters:

num_keys (int)

Return type:

jax.Array | list[jax.Array]
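
A hedged sketch of drawing keys inside a method of a Module subclass; that a single key comes back as a bare jax.Array while several come back as a list is an assumption read off the return annotation above:

    import jax

    # Inside a Module method:
    k1, k2 = self.get_rng_keys(2)                # assumed: list of two keys
    noise = jax.random.normal(k1, shape=(128,))  # use a key per draw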

abstractmethod __call__(**kwargs)[source]#

Executes the module’s forward computation.

Parameters:

kwargs (InputT)

Return type:

ModuleOutput
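
Putting the hooks together, a minimal sketch of a custom module. The class names, the gain field, and the array-like input are illustrative; only default_config, build, and __call__ come from the API documented above:

    import spark.nn as nn

    class GainConfig(nn.Config):
        gain: float = 2.0  # illustrative field; seed/dtype/dt come from Config

    class Gain(nn.Module):
        # Associate the module with its default configuration class.
        default_config = GainConfig

        def build(self, input_specs):
            # Allocate state from the input specs here; a stateless
            # scaling module needs nothing.
            pass

        def __call__(self, x):
            # Real modules would return a ModuleOutput; the bare product
            # is a simplification.
            return x * self.config.gain

    module = Gain(config=GainConfig(gain=0.5), name="gain")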

class spark.nn.Config(**kwargs)[source]#

Bases: BaseSparkConfig

Default class for module configuration.

seed: int[source]#
dtype: jax.typing.DTypeLike[source]#
dt: float[source]#
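
For illustration, instantiating the default Config with its three documented fields; the values are arbitrary:

    import jax.numpy as jnp
    import spark.nn as nn

    cfg = nn.Config(seed=0, dtype=jnp.float32, dt=1.0)
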
class spark.nn.BaseConfig(**kwargs)[source]#

Bases: abc.ABC

Base class for module configuration.

__config_delimiter__: str = '__'[source]#
__shared_config_delimiter__: str = '_s_'[source]#
__graph_editor_metadata__: dict[source]#
classmethod __init_subclass__(**kwargs)[source]#
merge(partial={})[source]#

Update config with partial overrides.

Parameters:

partial (dict[str, Any])

Return type:

None

diff(other)[source]#

Return differences from another config.

Parameters:

other (BaseSparkConfig)

Return type:

dict[str, Any]
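
A hedged sketch of merge and diff together, reusing the hypothetical GainConfig from the sketch above; nested overrides presumably use the '__' key convention from __config_delimiter__, though only the flat form is shown:

    base = GainConfig()
    cfg = GainConfig()

    cfg.merge({"gain": 4.0})   # in-place partial override; returns None
    print(cfg.diff(base))      # differences vs. base, e.g. {'gain': 4.0}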

validate()[source]#

Validates all fields in the configuration class.

Return type:

None

get_field_errors(field_name)[source]#

Validates a single field and returns its list of error messages.

Parameters:

field_name (str)

Return type:

list[str]

get_metadata()[source]#

Returns all the metadata in the configuration class, indexed by the attribute name.

Return type:

dict[str, Any]

property class_ref: type[source]#

Returns the type of the associated Module/Initializer.

NOTE: It is recommended to set __class_ref__ to the name of the associated module/initializer when defining custom configuration classes. The current class_ref solver is extremely brittle and likely to fail in many custom scenarios.

Return type:

type
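
Following the note above, a hedged example of pinning __class_ref__ in a custom configuration class; whether the value should be the class name string or the class object itself is an assumption:

    class GainConfig(nn.Config):
        # Pin the associated module explicitly instead of relying on the
        # brittle automatic solver (value format is an assumption).
        __class_ref__ = "Gain"
        gain: float = 2.0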

__post_init__()[source]#
to_dict()[source]#

Serialize the config to a dictionary.

Return type:

dict[str, dict[str, Any]]

classmethod from_dict(dct)[source]#

Create config instance from dictionary.

Parameters:

dct (dict)

Return type:

BaseSparkConfig

to_file(file_path)[source]#

Export the config instance to a .scfg file.

Parameters:

file_path (str)

Return type:

None

classmethod from_file(file_path)[source]#

Create config instance from a .scfg file.

Parameters:

file_path (str)

Return type:

BaseSparkConfig
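
A hedged round-trip sketch over the four serialization helpers, continuing the GainConfig example; the file name is arbitrary:

    cfg = GainConfig(gain=4.0)

    as_dict = cfg.to_dict()                    # nested dict representation
    restored = GainConfig.from_dict(as_dict)

    cfg.to_file("gain.scfg")                   # write to a .scfg file
    loaded = GainConfig.from_file("gain.scfg")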

__iter__()[source]#

Custom iterator to simplify SparkConfig inspection across the entire ecosystem. This iterator excludes private fields.

Yields:
  • field_name (str) – field name

  • field (dataclasses.Field) – field definition

  • field_value (Any) – field value

Return type:

Iterator[tuple[str, dataclasses.Field, Any]]
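
A minimal inspection loop matching the yield signature above; private fields are skipped automatically:

    for field_name, field, field_value in cfg:
        print(f"{field_name} = {field_value!r}")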

get_tree_structure()[source]#
Return type:

str

class spark.nn.Brain(config=None, **kwargs)[source]#

Bases: spark.core.module.SparkModule

Abstract brain model. This is primarily a convenience class that makes it easier to synchronize data between modules.

Parameters:

config (BrainConfig)

config: BrainConfig[source]#
resolve_initialization_order()[source]#

Resolves the initialization order of the modules.

build(input_specs)[source]#

Builds the brain from the given input specifications.

Parameters:

input_specs (dict[str, spark.core.specs.InputSpec])

reset()[source]#

Resets all modules to their initial state.

__call__(**inputs)[source]#

Updates the brain’s state.

Parameters:

inputs (spark.core.payloads.SparkPayload)

Return type:

tuple[spark.core.payloads.SparkPayload]

get_spikes_from_cache()[source]#

Collect the brain’s spikes.

Return type:

dict
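
A hedged end-to-end sketch of stepping a Brain; the configuration file, the port name "stimulus", and the payload variable are illustrative, and only the call pattern (reset, __call__, get_spikes_from_cache) is documented above:

    import spark.nn as nn

    brain_config = nn.BrainConfig.from_file("brain.scfg")  # assumed existing file
    brain = nn.Brain(config=brain_config)
    brain.reset()

    # `payload` stands in for a SparkPayload built elsewhere.
    outputs = brain(stimulus=payload)         # tuple[SparkPayload]
    spikes = brain.get_spikes_from_cache()    # dict of cached spikes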

class spark.nn.BrainConfig(**kwargs)[source]#

Bases: spark.core.config.BaseSparkConfig

Configuration class for Brain models.

input_map: dict[str, spark.core.specs.InputSpec][source]#
output_map: dict[str, dict][source]#
modules_map: dict[str, spark.core.specs.ModuleSpecs][source]#
validate()[source]#

Validates all fields in the configuration class.

Return type:

None