spark.nn#
Submodules#
Classes#
Module | Base class for Spark Modules.
Config | Default class for module configuration.
BaseConfig | Base class for module configuration.
Brain | Brain model.
BrainConfig | Configuration class for Brains.
Package Contents#
- class spark.nn.Module(*, config=None, name=None, **kwargs)[source]#
Bases:
flax.nnx.Module, abc.ABC, Generic[ConfigT, InputT]
Base class for Spark Modules.
- Parameters:
config (ConfigT | None)
name (str | None)
- classmethod get_config_spec()[source]#
Returns the default configuration class associated with this module.
- Return type:
- build(input_specs)[source]#
Build method.
- Parameters:
input_specs (dict[str, spark.core.specs.PortSpecs])
- Return type:
None
- set_contract_output_specs(contract_specs)[source]#
Pre-defines the expected shapes of the output specs (the recurrent shape policy).
This function is a binding contract that allows modules to accept self connections.
- Input:
contract_specs: dict[str, PortSpecs], a dictionary with a contract for the output specs.
NOTE: If both shape and output_specs are provided, output_specs takes precedence over shape.
- Parameters:
contract_specs (dict[str, spark.core.specs.PortSpecs])
- Return type:
None
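A minimal sketch of why an output contract helps with self connections, using toy stand-ins rather than Spark classes. `ToyPortSpec` and `ToyRecurrentModule` are hypothetical names introduced only for illustration, and the check inside `build` is an assumption about how a binding contract could be enforced, not a description of Spark's internals.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyPortSpec:
    """Hypothetical, simplified port spec: just a shape and a dtype name."""
    shape: tuple[int, ...]
    dtype: str = "float32"


class ToyRecurrentModule:
    """Hypothetical module whose input 'feedback' is wired to its own output 'out'."""

    def __init__(self, units: int):
        self.units = units
        self._contract = None
        self._output_specs = None

    def set_contract_output_specs(self, contract_specs: dict[str, ToyPortSpec]) -> None:
        # Promise the output specs before the module has been built.
        self._contract = dict(contract_specs)

    def get_contract_output_specs(self):
        return self._contract

    def build(self, input_specs: dict[str, ToyPortSpec]) -> None:
        # In this toy, the module always emits one vector with `units` entries.
        produced = {"out": ToyPortSpec(shape=(self.units,))}
        # The contract is binding: what build() produces must match the promise.
        if self._contract is not None and produced != self._contract:
            raise ValueError(f"contract violated: {produced} != {self._contract}")
        self._output_specs = produced


module = ToyRecurrentModule(units=4)
module.set_contract_output_specs({"out": ToyPortSpec(shape=(4,))})

# The self connection can be resolved from the contract, before build() runs.
feedback_spec = module.get_contract_output_specs()["out"]
module.build({"x": ToyPortSpec(shape=(8,)), "feedback": feedback_spec})
```

Without the contract, the spec of `feedback` would only be known after `build`, while `build` itself needs that spec as an input; declaring the output up front breaks the cycle.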
- get_contract_output_specs()[source]#
Retrieve the recurrent spec policy of the module.
- Return type:
dict[str, spark.core.specs.PortSpecs] | None
- get_input_specs()[source]#
Returns a dictionary of the SparkModule’s input port specifications.
- Return type:
- get_output_specs()[source]#
Returns a dictionary of the SparkModule’s output port specifications.
- Return type:
- get_rng_keys(num_keys)[source]#
Generates a new collection of random keys for JAX’s random engine.
- abstractmethod __call__(**kwargs)[source]#
Execution method.
- Parameters:
kwargs (InputT)
- Return type:
- class spark.nn.Config(__skip_validation__=False, **kwargs)[source]#
Bases:
BaseSparkConfig
Default class for module configuration.
- Parameters:
__skip_validation__ (bool)
- class spark.nn.BaseConfig(__skip_validation__=False, **kwargs)[source]#
Bases:
abc.ABC
Base class for module configuration.
- Parameters:
__skip_validation__ (bool)
- diff(other)[source]#
Return differences from another config.
- Parameters:
other (BaseSparkConfig)
- Return type:
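As a rough illustration of what a field-by-field config diff can look like, the sketch below compares two plain dataclasses. `ToyConfig` and `diff_configs` are hypothetical, and the returned mapping of field name to `(self_value, other_value)` is an assumption; the return type of `diff` is not stated on this page.

```python
from dataclasses import dataclass, fields


@dataclass
class ToyConfig:
    """Hypothetical config; the field names are illustrative only."""
    units: int = 16
    dt: float = 1.0
    name: str = "neuron"


def diff_configs(a, b) -> dict:
    """Map each differing field name to a (value_in_a, value_in_b) pair."""
    if type(a) is not type(b):
        raise TypeError("can only diff configs of the same class")
    changes = {}
    for f in fields(a):
        left, right = getattr(a, f.name), getattr(b, f.name)
        if left != right:
            changes[f.name] = (left, right)
    return changes


base = ToyConfig()
tuned = ToyConfig(units=32, dt=0.5)
changes = diff_configs(base, tuned)
print(changes)  # {'units': (16, 32), 'dt': (1.0, 0.5)}
```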
- validate(is_partial=False, errors=None, current_path=['main'])[source]#
Validates all fields in the configuration class.
- get_metadata()[source]#
Returns all the metadata in the configuration class, indexed by the attribute name.
- property class_ref: type[source]#
Returns the type of the associated Module/Initializer.
NOTE: It is recommended to set the __class_ref__ to the name of the associated module/initializer when defining custom configuration classes. The automatic class_ref solver is extremely brittle and likely to fail in many different custom scenarios.
- Return type:
- classmethod from_dict(dct)[source]#
Create config instance from dictionary.
- Parameters:
dct (dict)
- Return type:
- classmethod from_file(file_path, is_partial=False)[source]#
Create config instance from a .scfg file.
- Parameters:
- Return type:
- __iter__()[source]#
Custom iterator to simplify SparkConfig inspection across the entire ecosystem. This iterator excludes private fields.
- Output:
field_name: str, field name
field_value: tp.Any, field value
- Return type:
Iterator[tuple[str, dataclasses.Field, Any]]
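A small sketch of an iterator with the shape given by the return type above, written against a plain dataclass. `ToyConfig` is hypothetical, and treating a leading underscore as the marker for a private field is an assumption made for this example.

```python
import dataclasses
from typing import Any, Iterator


@dataclasses.dataclass
class ToyConfig:
    """Hypothetical config with two public fields and one private field."""
    units: int = 16
    dt: float = 1.0
    _cache_key: str = "internal"  # private: skipped by the iterator

    def __iter__(self) -> Iterator[tuple[str, dataclasses.Field, Any]]:
        for field in dataclasses.fields(self):
            if field.name.startswith("_"):
                continue
            # Yield the name, the dataclass Field object, and the current value.
            yield field.name, field, getattr(self, field.name)


cfg = ToyConfig()
public = [(name, value) for name, _field, value in cfg]
print(public)  # [('units', 16), ('dt', 1.0)]
```

Yielding the `dataclasses.Field` alongside the value gives callers access to per-field metadata without a second lookup.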
- class spark.nn.Brain(config=None, **kwargs)[source]#
Bases:
spark.core.module.SparkModule
Brain model.
A brain is a pipeline object used to represent and coordinate a collection of neurons and interfaces. This implementation relies on a cache system to simplify parallel computations: at every timestep, all the modules in the Brain read from the cache, update their internal state, and then update the cache. Note that this introduces a small latency between elements of the brain, which is negligible in most cases. For this reason, it is recommended that only full neuron models and interfaces be used within a Brain.
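The cache-based update and the latency it introduces can be sketched with plain Python, under the assumption that every module reads the cache as it stood at the end of the previous timestep. `step`, `source`, and `relay` are hypothetical names; this is not Spark's implementation.

```python
from typing import Callable

# Each toy module maps the previous cache snapshot to its new output value.
ToyModule = Callable[[dict[str, float]], float]


def step(cache: dict[str, float], modules: dict[str, ToyModule]) -> dict[str, float]:
    """Advance one timestep: all modules read the same (old) cache snapshot."""
    snapshot = dict(cache)  # frozen view shared by every module
    new_cache = dict(cache)
    for name, module in modules.items():
        # Evaluation order does not matter, since nobody reads new_cache here.
        new_cache[name] = module(snapshot)
    return new_cache


modules = {
    "source": lambda c: c["source"] + 1.0,  # counts up each timestep
    "relay": lambda c: c["source"],         # forwards what 'source' last emitted
}

cache = {"source": 0.0, "relay": 0.0}
history = []
for _ in range(3):
    cache = step(cache, modules)
    history.append((cache["source"], cache["relay"]))

print(history)  # [(1.0, 0.0), (2.0, 1.0), (3.0, 2.0)]
```

In this toy, `relay` always lags `source` by one timestep, which is the kind of latency the description refers to; in exchange, the modules are independent within a timestep and could be evaluated in parallel.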
- Parameters:
config (BrainConfig)
- config: BrainConfig[source]#
- build(input_specs)[source]#
Build method.
- Parameters:
input_specs (dict[str, spark.core.specs.PortSpecs])
- __call__(**inputs)[source]#
Updates the brain’s states.
- Parameters:
inputs (spark.core.payloads.SparkPayload)
- Return type:
- class spark.nn.BrainConfig(__skip_validation__=False, **kwargs)[source]#
Bases:
spark.core.config.BaseSparkConfig
Configuration class for Brains.
- Parameters:
__skip_validation__ (bool)
- input_map: dict[str, spark.core.specs.PortSpecs][source]#
- modules_map: dict[str, spark.core.specs.ModuleSpecs][source]#