spark.core.module
Attributes
Classes
ModuleOutput | Spark module output template.
SparkMeta | Metaclass for Spark Modules.
SparkModule | Base class for Spark Modules.
Module Contents
- class spark.core.module.ModuleOutput
Bases: TypedDict
Spark module output template.
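Because ModuleOutput derives from TypedDict, a module's output can presumably be described as a typed dictionary keyed by output port name. The sketch below uses only the standard library: `SpikeOutput` and its fields are hypothetical, and it does not subclass the real ModuleOutput (Spark is not imported).

```python
from typing import TypedDict, get_type_hints


# Hypothetical output template: one key per output port.
# Stand-in for a subclass of spark.core.module.ModuleOutput (not imported here).
class SpikeOutput(TypedDict):
    spikes: list[float]
    potential: float


def make_output() -> SpikeOutput:
    # A TypedDict is a plain dict at runtime; the annotations only guide type checkers.
    return SpikeOutput(spikes=[0.0, 1.0], potential=-65.0)


out = make_output()
print(out["potential"])                    # -65.0
print(list(get_type_hints(SpikeOutput)))  # ['spikes', 'potential']
```

Since the template is only a type annotation, nothing is validated at runtime; a static checker such as mypy is what catches a missing or misspelled port key.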
- class spark.core.module.SparkMeta
Bases: flax.nnx.module.ModuleMeta
Metaclass for Spark Modules.
- class spark.core.module.SparkModule(*, config=None, name=None, **kwargs)
Bases: flax.nnx.Module, abc.ABC, Generic[ConfigT, InputT]
Base class for Spark Modules.
- Parameters:
config (ConfigT | None)
name (str | None)
- classmethod get_config_spec()
Returns the default configuration class associated with this module.
- Return type:
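get_config_spec returns the default configuration class, and the base class is parameterized as Generic[ConfigT, InputT]. One plausible way such a lookup can work is to read the type arguments a subclass binds in its base list. The following is a stdlib-only analogue, not Spark's actual implementation: `MiniModule`, `LIFConfig`, and `LIFInputs` are hypothetical, and the fallbacks for config=None and name=None are assumptions.

```python
from dataclasses import dataclass
from typing import Generic, Optional, TypedDict, TypeVar, get_args, get_origin

ConfigT = TypeVar("ConfigT")
InputT = TypeVar("InputT")


class MiniModule(Generic[ConfigT, InputT]):
    """Hypothetical stand-in for SparkModule's generic parameterization."""

    def __init__(self, *, config: Optional[ConfigT] = None, name: Optional[str] = None):
        # Assumption: a missing config falls back to a default-constructed config class.
        self.config = config if config is not None else self.get_config_spec()()
        # Assumption: a missing name falls back to the class name.
        self.name = name or type(self).__name__

    @classmethod
    def get_config_spec(cls) -> type:
        # Find the MiniModule[Config, Input] base a subclass declared and return Config.
        for klass in cls.__mro__:
            for base in klass.__dict__.get("__orig_bases__", ()):
                if get_origin(base) is MiniModule:
                    return get_args(base)[0]
        raise TypeError(f"{cls.__name__} does not bind a config type")


@dataclass
class LIFConfig:
    tau: float = 20.0
    threshold: float = 1.0


class LIFInputs(TypedDict):
    current: float


class LIF(MiniModule[LIFConfig, LIFInputs]):
    pass


lif = LIF()
print(LIF.get_config_spec().__name__)  # LIFConfig
print(lif.config)                      # LIFConfig(tau=20.0, threshold=1.0)
```

Binding the config type once in the class header keeps the default configuration and the module's type signature in a single place.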
- build(input_specs)
Build method: builds the module from the specifications of its input ports.
- Parameters:
input_specs (dict[str, spark.core.specs.PortSpecs])
- Return type:
None
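build receives a dictionary mapping input port names to PortSpecs and returns None, which suggests a deferred-build pattern: the module learns its input shapes first and only then derives its state and output shapes. The sketch below illustrates that assumption with the standard library only; `PortSpecs` here is a hypothetical two-field dataclass, not spark.core.specs.PortSpecs, and `Dense`, `units`, and `weights_shape` are invented for the example.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class PortSpecs:
    """Hypothetical stand-in for spark.core.specs.PortSpecs."""
    shape: tuple[int, ...]
    dtype: str = "float32"


class Dense:
    """Hypothetical module whose parameter shapes depend on its input specs."""

    def __init__(self, units: int):
        self.units = units
        self._input_specs: dict[str, PortSpecs] | None = None
        self._output_specs: dict[str, PortSpecs] | None = None
        self.weights_shape: tuple[int, int] | None = None

    def build(self, input_specs: dict[str, PortSpecs]) -> None:
        # Record the incoming specs, then derive parameter and output shapes from them.
        in_spec = input_specs["x"]
        self._input_specs = dict(input_specs)
        self.weights_shape = (in_spec.shape[-1], self.units)
        out_shape = in_spec.shape[:-1] + (self.units,)
        self._output_specs = {"y": PortSpecs(out_shape, in_spec.dtype)}


layer = Dense(units=4)
layer.build({"x": PortSpecs(shape=(8, 3))})
print(layer.weights_shape)  # (3, 4)
```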
- set_contract_output_specs(contract_specs)
Sets a recurrent shape policy that pre-defines the expected shapes of the output specs.
This function acts as a binding contract that allows modules to accept self-connections.
- Input:
contract_specs: dict[str, PortSpecs], a dictionary with a contract for the output specs.
NOTE: If both shape and output_specs are provided, output_specs takes precedence over shape.
- Parameters:
contract_specs (dict[str, spark.core.specs.PortSpecs])
- Return type:
None
- get_contract_output_specs()
Retrieves the recurrent spec policy (the output-spec contract) of the module.
- Return type:
dict[str, spark.core.specs.PortSpecs] | None
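A self-connection creates a chicken-and-egg problem: a module whose input includes its own output cannot be built until its output specs are known, yet those are normally known only after build. A contract declared up front breaks the cycle. The sketch below models this with minimal stand-in classes and a hypothetical `resolve_self_input` helper; it is not Spark's wiring code, and validating the built specs against the contract is an assumption.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class PortSpecs:
    """Hypothetical stand-in for spark.core.specs.PortSpecs."""
    shape: tuple[int, ...]


class RecurrentCell:
    """Hypothetical module that feeds its own 'state' output back as an input."""

    def __init__(self) -> None:
        self._contract: dict[str, PortSpecs] | None = None
        self.output_specs: dict[str, PortSpecs] | None = None

    def set_contract_output_specs(self, contract_specs: dict[str, PortSpecs]) -> None:
        self._contract = dict(contract_specs)

    def get_contract_output_specs(self) -> dict[str, PortSpecs] | None:
        return self._contract

    def build(self, input_specs: dict[str, PortSpecs]) -> None:
        # The recurrent output keeps the shape of the fed-back state.
        self.output_specs = {"state": input_specs["prev_state"]}
        contract = self.get_contract_output_specs()
        # Assumption: the built output specs must honour the contract.
        if contract is not None and contract != self.output_specs:
            raise ValueError("built output specs violate the contract")


def resolve_self_input(module: RecurrentCell) -> dict[str, PortSpecs]:
    # Before build there are no output specs, so the contract is the only source.
    contract = module.get_contract_output_specs()
    if contract is None:
        raise RuntimeError("self-connection requires a contract on the output specs")
    return {"prev_state": contract["state"]}


cell = RecurrentCell()
cell.set_contract_output_specs({"state": PortSpecs(shape=(16,))})
cell.build(resolve_self_input(cell))
print(cell.output_specs)  # {'state': PortSpecs(shape=(16,))}
```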
- get_input_specs()
Returns a dictionary of the SparkModule’s input port specifications.
- Return type:
- get_output_specs()
Returns a dictionary of the SparkModule’s output port specifications.
- Return type:
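Since both getters return dictionaries keyed by port name, one natural use is checking that a producer's output port matches a consumer's input port before wiring them together. The sketch below assumes exact spec equality is the compatibility rule, which is a simplification; `StubModule`, `ports_compatible`, and the port names are hypothetical, and `PortSpecs` is a stand-in dataclass.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class PortSpecs:
    """Hypothetical stand-in for spark.core.specs.PortSpecs."""
    shape: tuple[int, ...]
    dtype: str = "float32"


class StubModule:
    """Hypothetical built module exposing its port specs as dictionaries."""

    def __init__(self, inputs: dict[str, PortSpecs], outputs: dict[str, PortSpecs]):
        self._inputs, self._outputs = inputs, outputs

    def get_input_specs(self) -> dict[str, PortSpecs]:
        # Return a copy so callers cannot mutate the module's internal state.
        return dict(self._inputs)

    def get_output_specs(self) -> dict[str, PortSpecs]:
        return dict(self._outputs)


def ports_compatible(src: StubModule, out_port: str, dst: StubModule, in_port: str) -> bool:
    # Assumption: a connection is valid when the two specs are exactly equal.
    return src.get_output_specs()[out_port] == dst.get_input_specs()[in_port]


encoder = StubModule({"x": PortSpecs((784,))}, {"z": PortSpecs((32,))})
decoder = StubModule({"z": PortSpecs((32,))}, {"x_hat": PortSpecs((784,))})
print(ports_compatible(encoder, "z", decoder, "z"))      # True
print(ports_compatible(decoder, "x_hat", decoder, "z"))  # False
```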
- get_rng_keys(num_keys)
Generates a new collection of num_keys random keys for JAX’s random engine.
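JAX has no global random state: keys are explicit values that should be split rather than reused. A method like get_rng_keys(num_keys) presumably hides this bookkeeping. The sketch below shows the standard split-and-carry pattern with jax.random.split; the `KeyStream` class is hypothetical, and how Spark actually stores its key (for example, in a Flax nnx RNG stream) is not documented here.

```python
import jax


class KeyStream:
    """Hypothetical holder of a JAX PRNG key that hands out fresh subkeys."""

    def __init__(self, seed: int):
        self._key = jax.random.PRNGKey(seed)

    def get_rng_keys(self, num_keys: int) -> jax.Array:
        # Split into num_keys + 1 keys: keep one as the new carried state
        # and return the rest, so no key is ever used twice.
        keys = jax.random.split(self._key, num_keys + 1)
        self._key = keys[0]
        return keys[1:]


stream = KeyStream(seed=0)
k1, k2 = stream.get_rng_keys(2)
noise = jax.random.normal(k1, shape=(3,))
print(noise.shape)  # (3,)
```

Carrying one split-off key forward is what makes repeated calls return different keys while keeping the whole sequence reproducible from the seed.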
- abstractmethod __call__(**kwargs)
Execution method. Subclasses must implement it; inputs are passed as keyword arguments typed by InputT.
- Parameters:
kwargs (InputT)
- Return type:
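Because __call__ is abstract, every concrete module must implement it. A stdlib-only analogue using abc.ABC shows the consequence: a subclass that omits __call__ cannot be instantiated. `BaseModule`, `Adder`, and their port names are hypothetical, and returning a plain dict of outputs is an assumption suggested by ModuleOutput being a TypedDict.

```python
import abc
from typing import Any


class BaseModule(abc.ABC):
    """Hypothetical stand-in for the abstract part of SparkModule."""

    @abc.abstractmethod
    def __call__(self, **kwargs: Any) -> dict[str, Any]:
        """Execution method: consume named inputs, return named outputs."""


class Adder(BaseModule):
    def __call__(self, **kwargs: Any) -> dict[str, Any]:
        # Inputs arrive by port name; outputs are returned keyed by port name.
        return {"sum": kwargs["a"] + kwargs["b"]}


class Incomplete(BaseModule):
    pass  # no __call__, so the class stays abstract


print(Adder()(a=2, b=3))  # {'sum': 5}
try:
    Incomplete()
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)  # cannot instantiate: TypeError
```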