spark.nn.initializers.base

Attributes

ConfigT

Classes

InitializerConfig

Base configuration class for initializers.

Initializer

Abstract base class for all Spark initializers.

Module Contents

class spark.nn.initializers.base.InitializerConfig(__skip_validation__=False, **kwargs)

Bases: spark.core.config.BaseSparkConfig, abc.ABC

Base configuration class for initializers.

Parameters:

__skip_validation__ (bool)

dtype: jax.typing.DTypeLike
scale: int | float
min_value: int | float | None
max_value: int | float | None
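The four fields above suggest how a concrete initializer might post-process its raw samples. The sketch below is a minimal illustration under an assumption this reference does not state: that `scale` multiplies the raw values, `min_value`/`max_value` clip them when set, and `dtype` sets the output type. `SketchInitializerConfig` and `apply_config` are hypothetical names, not part of Spark.

```python
from dataclasses import dataclass
from typing import Any, Optional

import jax.numpy as jnp


@dataclass
class SketchInitializerConfig:
    # Hypothetical stand-in for InitializerConfig; field names mirror the documented attributes.
    dtype: Any = jnp.float32
    scale: float = 1.0
    min_value: Optional[float] = None
    max_value: Optional[float] = None


def apply_config(values, cfg):
    """Scale raw samples, clip them to the optional bounds, then cast (assumed semantics)."""
    out = values * cfg.scale
    if cfg.min_value is not None or cfg.max_value is not None:
        # jnp.clip leaves a side unbounded when that bound is None.
        out = jnp.clip(out, cfg.min_value, cfg.max_value)
    return out.astype(cfg.dtype)


# Usage: [-2.0, 0.5, 3.0] * 2 = [-4.0, 1.0, 6.0], clipped to [-1.0, 1.0, 1.0].
x = jnp.array([-2.0, 0.5, 3.0])
cfg = SketchInitializerConfig(scale=2.0, min_value=-1.0, max_value=1.0, dtype=jnp.float16)
y = apply_config(x, cfg)
```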
spark.nn.initializers.base.ConfigT
class spark.nn.initializers.base.Initializer(*, config=None, **kwargs)

Bases: abc.ABC

Abstract base class for all Spark initializers.

Parameters:

config (ConfigT | None)

config: InitializerConfig
default_config: type[ConfigT]
classmethod __init_subclass__(**kwargs)
Return type:

None

classmethod get_config_spec()

Returns the default configuration class associated with this initializer class.

Return type:

type[InitializerConfig]
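`default_config`, `get_config_spec()` and the `config=None` constructor argument suggest a common pattern: each subclass names its configuration class, and a config instance is built from keyword arguments when none is passed. The plain-Python sketch below imitates that pattern with `abc`, `TypeVar` and dataclasses. `BaseConfig`, `BaseInitializer`, `ConstantConfig` and `ConstantInitializer` are hypothetical, and the kwargs fallback is an assumption about Spark's behavior rather than a documented guarantee.

```python
import abc
from dataclasses import dataclass
from typing import ClassVar, Generic, Optional, TypeVar


@dataclass
class BaseConfig:
    # Hypothetical stand-in for InitializerConfig (not the real Spark class).
    scale: float = 1.0


# Type variable playing the role of ConfigT: any subclass of BaseConfig.
ConfigT = TypeVar("ConfigT", bound=BaseConfig)


class BaseInitializer(abc.ABC, Generic[ConfigT]):
    # Each concrete subclass declares the config class it expects.
    default_config: ClassVar[type]

    def __init__(self, *, config: Optional[ConfigT] = None, **kwargs):
        # Assumed behavior: build a default config from kwargs when none is given.
        self.config = config if config is not None else self.default_config(**kwargs)

    @classmethod
    def get_config_spec(cls) -> type:
        # Return the configuration class associated with this initializer class.
        return cls.default_config


@dataclass
class ConstantConfig(BaseConfig):
    value: float = 0.0


class ConstantInitializer(BaseInitializer[ConstantConfig]):
    default_config = ConstantConfig


# Usage: kwargs are routed into ConstantConfig.
init = ConstantInitializer(value=0.5, scale=2.0)
```

Keeping the config class on the initializer class (rather than on each instance) lets tooling inspect the expected configuration through `get_config_spec()` without constructing an initializer first.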

abstractmethod __call__(key, shape)
Parameters:

key

shape

Return type:

jax.Array
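Concrete initializers must implement the abstract `__call__(key, shape)` and return a `jax.Array`. A minimal sketch, assuming `key` is a JAX PRNG key and `shape` is a tuple of ints: `SketchInitializer` stands in for the Spark base class, and `ScaledUniform` is a hypothetical subclass that samples uniformly from `[-scale, scale)` with `jax.random.uniform`.

```python
import abc

import jax
import jax.numpy as jnp


class SketchInitializer(abc.ABC):
    # Hypothetical stand-in for spark.nn.initializers.base.Initializer.
    @abc.abstractmethod
    def __call__(self, key: jax.Array, shape: tuple) -> jax.Array:
        ...


class ScaledUniform(SketchInitializer):
    """Draws values uniformly from [-scale, scale) (illustrative only)."""

    def __init__(self, scale: float = 1.0, dtype=jnp.float32):
        self.scale = scale
        self.dtype = dtype

    def __call__(self, key, shape):
        # jax.random.uniform samples in [minval, maxval) with the requested shape and dtype.
        return jax.random.uniform(
            key, shape, dtype=self.dtype, minval=-self.scale, maxval=self.scale
        )


# Usage: a 3x4 weight matrix drawn from a fixed key.
key = jax.random.PRNGKey(0)
w = ScaledUniform(scale=0.1)(key, (3, 4))
```

Because the randomness comes only from `key`, calling the same initializer twice with the same key yields identical arrays, which keeps parameter initialization reproducible.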