spark.core.module
Attributes
Classes
ModuleOutput | Spark module output template.
SparkMeta | Metaclass for Spark Modules.
SparkModule | Base class for Spark Modules.
Module Contents
- class spark.core.module.ModuleOutput
Bases: TypedDict
Spark module output template.
Initialize self. See help(type(self)) for accurate signature.
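Because ModuleOutput is a TypedDict, a concrete module's output template can presumably be declared as a subclass that names each output field and its type. The minimal sketch below uses a local stand-in for ModuleOutput (so it runs without Spark installed) and a hypothetical `LIFOutput` with hypothetical `spikes` and `potential` fields; the real field names depend on each module.

```python
from typing import TypedDict, get_type_hints


# Local stand-in for spark.core.module.ModuleOutput, which is itself a TypedDict.
class ModuleOutput(TypedDict):
    pass


# Hypothetical output template; the field names are illustrative only.
class LIFOutput(ModuleOutput):
    spikes: list[float]
    potential: list[float]


# At runtime a TypedDict value is an ordinary dict; the class only informs type checkers.
out: LIFOutput = {"spikes": [0.0, 1.0], "potential": [-65.0, -50.0]}
print(sorted(get_type_hints(LIFOutput)))  # ['potential', 'spikes']
```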
- class spark.core.module.SparkMeta
Bases: flax.nnx.module.ModuleMeta
Metaclass for Spark Modules.
- class spark.core.module.SparkModule(*, config=None, name=None, **kwargs)
Bases: flax.nnx.Module, abc.ABC, Generic[ConfigT, InputT]
Base class for Spark Modules.
- Parameters:
config (ConfigT | None)
name (str | None)
- classmethod get_config_spec()
Returns the default configuration class associated with this module.
- Return type:
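A base class that is Generic over its configuration type can recover a subclass's default configuration class from how that subclass was parameterized. The sketch below shows one such mechanism under stated assumptions: `ModuleBase`, `LIFConfig`, and `LIF` are hypothetical stand-ins, and the lookup via `__orig_bases__` illustrates the general technique rather than how SparkModule actually implements get_config_spec.

```python
from dataclasses import dataclass
from typing import Generic, TypeVar, get_args

ConfigT = TypeVar("ConfigT")


class ModuleBase(Generic[ConfigT]):
    """Hypothetical stand-in for a module base class that is Generic over its config."""

    @classmethod
    def get_config_spec(cls) -> type:
        # Walk the MRO and return the first concrete class supplied as the
        # first type argument (assumed to be ConfigT) of a parameterized base.
        for klass in cls.__mro__:
            for base in klass.__dict__.get("__orig_bases__", ()):
                args = get_args(base)
                if args and isinstance(args[0], type):
                    return args[0]
        raise TypeError(f"{cls.__name__} is not parameterized with a config class")


@dataclass
class LIFConfig:  # hypothetical configuration class
    tau: float = 20.0


class LIF(ModuleBase[LIFConfig]):  # hypothetical module
    pass


default_config = LIF.get_config_spec()()
print(default_config.tau)  # 20.0
```

Because the lookup walks the MRO, a further subclass of `LIF` that does not re-parameterize the base still resolves to `LIFConfig`.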
- build(input_specs)
Builds the module from the specifications of its input ports.
- Parameters:
input_specs (dict[str, spark.core.specs.InputSpec])
- Return type:
None
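Receiving a dict of input specifications in build() suggests a deferred-build pattern, where state whose size depends on the inputs is only created once the input shapes are known. The sketch below illustrates that pattern under assumptions: `InputSpecStub` is a hypothetical stand-in for spark.core.specs.InputSpec carrying only a `shape`, and `DenseSketch` is a hypothetical module, not part of Spark.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class InputSpecStub:
    """Hypothetical stand-in for spark.core.specs.InputSpec (shape only)."""
    shape: tuple[int, ...]


class DenseSketch:
    """Hypothetical module that defers sizing its weights until build()."""

    def __init__(self, units: int) -> None:
        self.units = units
        self.weight_shape = None  # unknown until build() sees the inputs

    def build(self, input_specs: dict[str, InputSpecStub]) -> None:
        # Treat all input ports as concatenated along their last axis.
        in_features = sum(spec.shape[-1] for spec in input_specs.values())
        self.weight_shape = (in_features, self.units)


layer = DenseSketch(units=8)
layer.build({"in_a": InputSpecStub((4,)), "in_b": InputSpecStub((2,))})
print(layer.weight_shape)  # (6, 8)
```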
- set_recurrent_shape_contract(shape=None, output_shapes=None)
Sets the recurrent shape policy, which pre-defines the expected shapes of the output specs.
This function acts as a binding contract that allows the module to accept self-connections.
- Parameters:
shape (tuple[int, …] | None): a common shape for all the outputs.
output_shapes (dict[str, tuple[int, …]] | None): a specific shape for every individual output variable.
NOTE: If both shape and output_shapes are provided, output_shapes takes precedence over shape.
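The precedence rule above can be sketched as a small resolution function. This is a hypothetical helper, not part of Spark, and it assumes two behaviours the documentation does not spell out: when output_shapes is given it must list every output, and calling with neither argument is an error.

```python
from typing import Dict, Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def resolve_recurrent_shapes(
    output_names: Sequence[str],
    shape: Optional[Shape] = None,
    output_shapes: Optional[Dict[str, Shape]] = None,
) -> Dict[str, Shape]:
    """Hypothetical helper mirroring the documented precedence rule."""
    if output_shapes is not None:
        # output_shapes wins; assumed to require an entry for every output.
        missing = set(output_names) - set(output_shapes)
        if missing:
            raise ValueError(f"no shape given for outputs: {sorted(missing)}")
        return {name: tuple(output_shapes[name]) for name in output_names}
    if shape is not None:
        # A single common shape is broadcast to every output.
        return {name: tuple(shape) for name in output_names}
    raise ValueError("provide shape or output_shapes")


print(resolve_recurrent_shapes(["spikes", "potential"], shape=(100,)))
# {'spikes': (100,), 'potential': (100,)}
print(resolve_recurrent_shapes(["spikes"], shape=(100,), output_shapes={"spikes": (10, 10)}))
# {'spikes': (10, 10)}
```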
- get_input_specs()
Returns a dictionary of the SparkModule’s input port specifications.
- Return type:
- get_output_specs()
Returns a dictionary of the SparkModule’s output port specifications.
- Return type:
- get_rng_keys(num_keys)
Generates a new collection of random keys for JAX’s random number generator.
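JAX random functions are pure and consume explicit keys, so a method that hands out fresh keys typically splits an internally held key and keeps one piece as its new state. The sketch below shows that standard idiom with `jax.random.split`; `RngKeySketch` and its seed argument are hypothetical, and the sketch does not claim to reproduce how SparkModule stores or splits its keys.

```python
import jax


class RngKeySketch:
    """Hypothetical holder mimicking a get_rng_keys(num_keys) method."""

    def __init__(self, seed: int) -> None:
        self._key = jax.random.PRNGKey(seed)

    def get_rng_keys(self, num_keys: int):
        # Split once: the first key becomes the new internal state and the
        # remaining num_keys subkeys are returned, so no key is reused.
        self._key, *subkeys = jax.random.split(self._key, num_keys + 1)
        return subkeys


holder = RngKeySketch(seed=0)
k1, k2 = holder.get_rng_keys(2)
noise = jax.random.normal(k1, (3,))
print(noise.shape)  # (3,)
```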