spark.nn#
Submodules#
Classes#
- Module: Base class for Spark Modules.
- Config: Default class for module configuration.
- BaseConfig: Base class for module configuration.
- Brain: Abstract brain model.
- BrainConfig: Configuration class for Brain.
Package Contents#
- class spark.nn.Module(*, config=None, name=None, **kwargs)[source]#
Bases:
flax.nnx.Module, abc.ABC, Generic[ConfigT, InputT]
Base class for Spark Modules.
- Parameters:
config (ConfigT | None)
name (str | None)
- classmethod get_config_spec()[source]#
Returns the default configuration class associated with this module.
- Return type:
- build(input_specs)[source]#
Builds the module from the given input specifications.
- Parameters:
input_specs (dict[str, spark.core.specs.InputSpec])
- Return type:
None
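For orientation, here is a minimal sketch of what a custom module might look like. Everything specific in it is hypothetical: the LeakyNeuron and LeakyConfig names, the size field, and the body of build are invented; only the Module and Config signatures documented on this page are assumed.

```python
import spark.nn as nn

class LeakyConfig(nn.Config):
    # Hypothetical field; real configs define their own attributes.
    size: int = 128

class LeakyNeuron(nn.Module):
    """Hypothetical leaky neuron built on the documented Module API."""

    def __init__(self, *, config=None, name=None, **kwargs):
        super().__init__(config=config, name=name, **kwargs)

    def build(self, input_specs):
        # input_specs maps port names to spark.core.specs.InputSpec;
        # parameters/state would be allocated from the declared shapes here.
        ...
```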
- set_recurrent_shape_contract(shape=None, output_shapes=None)[source]#
The recurrent shape contract pre-defines the expected shapes of the module's output specs.
Calling this function is a binding contract that allows modules to accept self-connections.
- Input:
shape: tuple[int, ...], a common shape shared by all outputs.
output_shapes: dict[str, tuple[int, ...]], a specific shape policy for each individual output variable.
NOTE: If both shape and output_shapes are provided, output_shapes takes precedence over shape.
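As a usage sketch, given an existing Module instance `module`: the port names below are invented, and only the two keyword parameters come from the signature above.

```python
# One common shape for every output port:
module.set_recurrent_shape_contract(shape=(128,))

# Or a per-output policy, which takes precedence over `shape`;
# the 'spikes' and 'potential' port names are placeholders.
module.set_recurrent_shape_contract(
    output_shapes={'spikes': (128,), 'potential': (128,)},
)
```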
- get_input_specs()[source]#
Returns a dictionary of the SparkModule’s input port specifications.
- Return type:
- get_output_specs()[source]#
Returns a dictionary of the SparkModule’s output port specifications.
- Return type:
- get_rng_keys(num_keys)[source]#
Generates a new collection of random keys for JAX’s random engine.
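A hedged usage sketch, given an existing Module instance `module`; it assumes the returned collection of keys can be unpacked and that each key is a standard JAX PRNG key.

```python
import jax

k1, k2 = module.get_rng_keys(2)         # assumption: unpackable pair of keys
noise = jax.random.normal(k1, (128,))   # assumption: usable with jax.random
```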
- class spark.nn.Config(**kwargs)[source]#
Bases:
BaseSparkConfig
Default class for module configuration.
- class spark.nn.BaseConfig(**kwargs)[source]#
Bases:
abc.ABC
Base class for module configuration.
- diff(other)[source]#
Return differences from another config.
- Parameters:
other (BaseSparkConfig)
- Return type:
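A short sketch of how diff might be used; MyConfig and its decay field are hypothetical, and the structure of the returned diff is not specified on this page.

```python
a = MyConfig(decay=0.9)   # hypothetical config class and field
b = MyConfig(decay=0.5)
print(a.diff(b))          # report of fields that differ between a and b
```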
- get_metadata()[source]#
Returns all the metadata in the configuration class, indexed by the attribute name.
- property class_ref: type[source]#
Returns the type of the associated Module/Initializer.
NOTE: It is recommended to set __class_ref__ to the name of the associated module/initializer when defining custom configuration classes. The current class_ref solver is extremely brittle and likely to fail in many custom scenarios.
- Return type:
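Following that note, a custom configuration class would pin its reference explicitly. The exact value __class_ref__ expects (here assumed to be the module's name as a string) is not documented on this page.

```python
import spark.nn as nn

class MyModuleConfig(nn.BaseConfig):
    # Pin the associated module explicitly rather than relying on the
    # brittle automatic class_ref solver (string name is an assumption).
    __class_ref__ = 'MyModule'
```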
- classmethod from_dict(dct)[source]#
Create config instance from dictionary.
- Parameters:
dct (dict)
- Return type:
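For example, with the hypothetical MyModuleConfig above (field names are illustrative):

```python
cfg = MyModuleConfig.from_dict({'size': 128, 'decay': 0.9})
```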
- to_file(file_path)[source]#
Export a config instance to a .scfg file.
- Parameters:
file_path (str)
- Return type:
None
- classmethod from_file(file_path)[source]#
Create config instance from a .scfg file.
- Parameters:
file_path (str)
- Return type:
None
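Together with to_file this gives a simple round trip; MyModuleConfig and its size field remain hypothetical.

```python
cfg = MyModuleConfig(size=128)                         # hypothetical field
cfg.to_file('my_module.scfg')                          # write config to disk
restored = MyModuleConfig.from_file('my_module.scfg')  # read it back
```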
- __iter__()[source]#
Custom iterator to simplify SparkConfig inspection across the entire ecosystem. This iterator excludes private fields.
- Output:
field_name: str, the field name.
field_value: tp.Any, the field value.
- Return type:
Iterator[tuple[str, dataclasses.Field, Any]]
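A sketch of iterating a config instance `cfg`. Note a discrepancy on this page: the Output section lists two values per item while the return annotation shows a three-tuple of (name, Field, value); the loop below avoids committing to either by printing items whole.

```python
for item in cfg:   # private fields are skipped by the iterator
    print(item)
```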
- class spark.nn.Brain(config=None, **kwargs)[source]#
Bases:
spark.core.module.SparkModule
Abstract brain model. This is primarily a convenience class that makes data synchronization easier.
- Parameters:
config (BrainConfig)
- config: BrainConfig[source]#
- build(input_specs)[source]#
Builds the brain from the given input specifications.
- Parameters:
input_specs (dict[str, spark.core.specs.InputSpec])
- __call__(**inputs)[source]#
Update brain’s states.
- Parameters:
inputs (spark.core.payloads.SparkPayload)
- Return type:
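A hypothetical update step; the 'retina' port name and retina_payload are placeholders, and only the keyword-argument calling convention and the SparkPayload input type come from this page.

```python
import spark.nn as nn

brain = nn.Brain(config=brain_config)   # brain_config: nn.BrainConfig
# retina_payload stands in for a spark.core.payloads.SparkPayload.
outputs = brain(retina=retina_payload)
```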
- class spark.nn.BrainConfig(**kwargs)[source]#
Bases:
spark.core.config.BaseSparkConfig
Configuration class for Brain.
- input_map: dict[str, spark.core.specs.InputSpec][source]#
- modules_map: dict[str, spark.core.specs.ModuleSpecs][source]#
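To close, a sketch of wiring a BrainConfig. Only the input_map and modules_map field names come from this page; the InputSpec and ModuleSpecs constructor arguments are left as placeholders because they are not documented here.

```python
import spark.nn as nn
from spark.core.specs import InputSpec, ModuleSpecs

cfg = nn.BrainConfig(
    input_map={'retina': InputSpec(...)},    # placeholder arguments
    modules_map={'v1': ModuleSpecs(...)},    # placeholder arguments
)
brain = nn.Brain(config=cfg)
```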