spark.core.module#

Attributes#

Classes#

ModuleOutput

Spark module output template

SparkMeta

Metaclass for Spark Modules.

SparkModule

Base class for Spark Modules

Module Contents#

spark.core.module.ConfigT[source]#
spark.core.module.InputT[source]#
class spark.core.module.ModuleOutput[source]#

Bases: TypedDict

Spark module output template
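As a rough illustration of how a TypedDict-based output template can be used, the sketch below defines a stand-in `ModuleOutput` and a hypothetical `NeuronOutput` subclass. The names `NeuronOutput`, `spikes`, and `potential` are assumptions made for the example and are not taken from the library.

```python
from typing import TypedDict


# Stand-in for spark.core.module.ModuleOutput, which is also a TypedDict.
class ModuleOutput(TypedDict):
    """Empty template; concrete modules declare their own output keys."""


# Hypothetical output type for an imaginary module with two output ports.
class NeuronOutput(ModuleOutput):
    spikes: list[float]
    potential: list[float]


# At runtime a TypedDict value is a plain dict; the declared keys are
# checked only by static type checkers such as mypy or pyright.
out: NeuronOutput = {"spikes": [0.0, 1.0], "potential": [-0.3, 0.9]}
print(sorted(out))
```

Because the template is only a typing aid, a module can return an ordinary dictionary as long as its keys match the declared output type.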


class spark.core.module.SparkMeta[source]#

Bases: flax.nnx.module.ModuleMeta

Metaclass for Spark Modules.

class spark.core.module.SparkModule(*, config=None, name=None, **kwargs)[source]#

Bases: flax.nnx.Module, abc.ABC, Generic[ConfigT, InputT]

Base class for Spark Modules

Parameters:
  • config (ConfigT | None)

  • name (str | None)

name: str = 'name'[source]#
config: ConfigT[source]#
default_config: type[ConfigT][source]#
classmethod __init_subclass__(**kwargs)[source]#
__built__: bool = False[source]#
__allow_cycles__: bool = False[source]#
classmethod get_config_spec()[source]#

Returns the default configuration class associated with this module.

Return type:

type[spark.core.config.BaseSparkConfig]
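The relationship between `default_config` and `get_config_spec()` can be sketched with plain Python classes. Everything below is a hypothetical stand-in: `BaseConfig`, `LeakyConfig`, `Module`, and `Leaky` are invented for the example, and falling back to a default-constructed configuration when `config` is `None` is an assumption, not documented behavior.

```python
from dataclasses import dataclass


# Hypothetical stand-in for spark.core.config.BaseSparkConfig.
@dataclass
class BaseConfig:
    seed: int = 0


# Hypothetical configuration for an imaginary module.
@dataclass
class LeakyConfig(BaseConfig):
    decay: float = 0.9


class Module:
    # Each subclass points default_config at its own configuration class.
    default_config: type[BaseConfig] = BaseConfig

    def __init__(self, *, config=None, name=None):
        # Assumption: use a default-constructed configuration when none is given.
        self.config = config if config is not None else self.get_config_spec()()
        self.name = name

    @classmethod
    def get_config_spec(cls) -> type[BaseConfig]:
        # Return the configuration class itself, not an instance.
        return cls.default_config


class Leaky(Module):
    default_config = LeakyConfig


spec = Leaky.get_config_spec()
module = Leaky(name="leaky_0")
```

Returning the class rather than an instance lets callers inspect the available fields or build a customised configuration before constructing the module.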

build(input_specs)[source]#

Build method.

Parameters:

input_specs (dict[str, spark.core.specs.InputSpec])

Return type:

None

reset()[source]#

Reset module to its default state.

set_recurrent_shape_contract(shape=None, output_shapes=None)[source]#

The recurrent shape policy predefines the expected shapes of the output specs.

This function is a binding contract that allows modules to accept self-connections.

Parameters:
  • shape (tuple[int, …]): A common shape for all the outputs.

  • output_shapes (dict[str, tuple[int, …]]): A specific shape for each output variable.

NOTE: If both shape and output_shapes are provided, output_shapes takes precedence over shape.

Return type:

None
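The precedence rule between `shape` and `output_shapes` can be sketched as a small pure-Python helper. `resolve_output_shapes` and its `output_names` argument are hypothetical and not part of the library. Raising an error when neither argument is given is also an assumption.

```python
from typing import Optional

Shape = tuple[int, ...]


def resolve_output_shapes(
    output_names: list[str],
    shape: Optional[Shape] = None,
    output_shapes: Optional[dict[str, Shape]] = None,
) -> dict[str, Shape]:
    """Hypothetical helper: turn the two arguments into one shape per output."""
    if output_shapes is not None:
        # The per-output policy wins, even if a common shape was also given.
        return dict(output_shapes)
    if shape is not None:
        # Broadcast the common shape to every output.
        return {name: shape for name in output_names}
    # Assumption: calling with neither argument is treated as an error.
    raise ValueError("Provide either shape or output_shapes.")


common = resolve_output_shapes(["spikes", "potential"], shape=(8,))
mixed = resolve_output_shapes(
    ["spikes", "potential"],
    shape=(8,),
    output_shapes={"spikes": (8,), "potential": (8, 2)},
)
```

Fixing the output shapes ahead of time is what makes a self-connection resolvable: the shape of the module's own output is known before its inputs have been inspected.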

get_recurrent_shape_contract()[source]#

Retrieve the recurrent shape policy of the module.

get_input_specs()[source]#

Returns a dictionary of the SparkModule’s input port specifications.

Return type:

dict[str, spark.core.specs.InputSpec]

get_output_specs()[source]#

Returns a dictionary of the SparkModule’s output port specifications.

Return type:

dict[str, spark.core.specs.OutputSpec]

get_rng_keys(num_keys)[source]#

Generates a new collection of random keys for JAX’s random engine.

Parameters:

num_keys (int)

Return type:

jax.Array | list[jax.Array]

abstractmethod __call__(**kwargs)[source]#

Execution method.

Parameters:

kwargs (InputT)

Return type:

ModuleOutput
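The abstract `__call__` contract can be illustrated without the `flax.nnx` machinery. In the sketch below, `Module` is a hypothetical stand-in for `SparkModule`, and `Scale`, `ScaleOutput`, and the `values` input are invented for the example. Treating each keyword argument as one input port is an assumption.

```python
from abc import ABC, abstractmethod
from typing import Any, Generic, TypedDict, TypeVar

ConfigT = TypeVar("ConfigT")
InputT = TypeVar("InputT")


class ModuleOutput(TypedDict):
    """Stand-in for the library's output template."""


class Module(ABC, Generic[ConfigT, InputT]):
    """Hypothetical stand-in for SparkModule, without flax.nnx."""

    @abstractmethod
    def __call__(self, **kwargs: Any) -> ModuleOutput:
        ...


class ScaleOutput(ModuleOutput):
    scaled: list[float]


class Scale(Module[dict, dict]):
    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __call__(self, **kwargs: Any) -> ScaleOutput:
        # Assumption: inputs arrive as keyword arguments, one per input port.
        values = kwargs["values"]
        return {"scaled": [self.factor * v for v in values]}


result = Scale(2.0)(values=[1.0, 2.5])

try:
    Module()  # the abstract base cannot be instantiated directly
    base_is_abstract = False
except TypeError:
    base_is_abstract = True
```

Subclasses that do not override `__call__` fail at instantiation time, so a missing execution method surfaces before the module is wired into a larger model.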