spark.nn.initializers

Classes

Initializer

Base (abstract) class for all Spark initializers.

InitializerConfig

Base configuration class for initializers.

ConstantInitializer

Initializer that returns constant-valued arrays.

ConstantInitializerConfig

ConstantInitializer configuration class.

UniformInitializer

Initializer that returns real uniformly-distributed random arrays.

UniformInitializerConfig

UniformInitializer configuration class.

SparseUniformInitializer

Initializer that returns real, sparse, uniformly distributed random arrays.

SparseUniformInitializerConfig

SparseUniformInitializer configuration class.

NormalizedSparseUniformInitializer

Initializer that returns normalized, real, sparse, uniformly distributed random arrays.

NormalizedSparseUniformInitializerConfig

NormalizedSparseUniformInitializer configuration class.

Package Contents

class spark.nn.initializers.Initializer(*, config=None, **kwargs)

Bases: abc.ABC

Base (abstract) class for all Spark initializers.

Parameters:

config (ConfigT | None)

config: InitializerConfig
default_config: type[ConfigT]
classmethod __init_subclass__(**kwargs)
Return type:

None

abstractmethod __call__(key, shape)
Parameters:

key (jax.Array)

shape (tuple[int, …])

Return type:

jax.Array
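
The call contract above (a random key and a shape go in, a jax.Array comes out) can be illustrated with a minimal sketch. The names SketchInitializer and ZerosSketch are hypothetical stand-ins written for this example and are not part of Spark; configuration handling is left out entirely.

```python
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp


class SketchInitializer(ABC):
    """Hypothetical stand-in for an initializer base class (not Spark's code)."""

    @abstractmethod
    def __call__(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
        """Return an array with the requested shape."""


class ZerosSketch(SketchInitializer):
    """Trivial concrete subclass used only to exercise the interface."""

    def __call__(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
        # The key is accepted to match the (key, shape) signature but is unused.
        return jnp.zeros(shape, dtype=jnp.float32)


key = jax.random.key(0)
weights = ZerosSketch()(key, (2, 3))
```

Because __call__ is abstract, instantiating SketchInitializer directly raises a TypeError; only concrete subclasses can be called.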

class spark.nn.initializers.InitializerConfig(**kwargs)

Bases: spark.core.config.BaseSparkConfig

Base configuration class for initializers.

dtype: jax.typing.DTypeLike
scale: T
min_value: T | None
max_value: T | None

class spark.nn.initializers.ConstantInitializer(*, config=None, **kwargs)

Bases: spark.nn.initializers.base.Initializer

Initializer that returns constant-valued arrays.

Init:

scale: numeric, value for the output array (default = 1).

Input:

key: jax.Array, key for the random generator (jax.random.key).

shape: tuple[int, …], shape for the output array.

Parameters:

config (ConfigT | None)

config: ConstantInitializerConfig
__call__(key, shape)
Parameters:

key (jax.Array)

shape (tuple[int, …])

Return type:

jax.Array
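
As a rough illustration of the behaviour described for this class, the sketch below fills an array with a single value. The helper constant_init_sketch is hypothetical and is not Spark's implementation; it assumes that scale is the fill value and that the key is accepted only for signature compatibility.

```python
import jax
import jax.numpy as jnp


def constant_init_sketch(key, shape, scale=1.0, dtype=jnp.float32):
    """Hypothetical helper: every entry of the output equals `scale`."""
    del key  # accepted only to match the (key, shape) call signature
    return jnp.full(shape, scale, dtype=dtype)


out = constant_init_sketch(jax.random.key(0), (2, 3), scale=0.5)
```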

class spark.nn.initializers.ConstantInitializerConfig(**kwargs)

Bases: spark.nn.initializers.base.InitializerConfig

ConstantInitializer configuration class.

__class_ref__: str = 'ConstantInitializer'

class spark.nn.initializers.UniformInitializer(*, config=None, **kwargs)

Bases: spark.nn.initializers.base.Initializer

Initializer that returns real uniformly-distributed random arrays.

Init:

scale: numeric, multiplicative factor for the output array (default = 1).

min_value: numeric, minimum value for the output array (default = None).

max_value: numeric, maximum value for the output array (default = None).

Input:

key: jax.Array, key for the random generator (jax.random.key).

shape: tuple[int, …], shape for the output array.

Parameters:

config (ConfigT | None)

config: UniformInitializerConfig
__call__(key, shape)
Parameters:

key (jax.Array)

shape (tuple[int, …])

Return type:

jax.Array
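
The parameters above suggest a sample-then-scale pattern, sketched below in plain JAX. The helper uniform_init_sketch is hypothetical; it assumes that missing bounds fall back to 0 and 1 and that scale multiplies the sample after it is drawn. Neither assumption is confirmed by this page.

```python
import jax
import jax.numpy as jnp


def uniform_init_sketch(key, shape, scale=1.0, min_value=None,
                        max_value=None, dtype=jnp.float32):
    """Hypothetical helper: uniform sample in [min_value, max_value), then scaled."""
    # Assumption: None bounds default to the unit interval.
    lo = 0.0 if min_value is None else min_value
    hi = 1.0 if max_value is None else max_value
    sample = jax.random.uniform(key, shape, dtype=dtype, minval=lo, maxval=hi)
    # Assumption: `scale` is applied after sampling.
    return scale * sample


out = uniform_init_sketch(jax.random.key(0), (4, 5), scale=2.0,
                          min_value=-1.0, max_value=1.0)
```

Under these assumptions the values of out lie between -2 and 2.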

class spark.nn.initializers.UniformInitializerConfig(**kwargs)

Bases: spark.nn.initializers.base.InitializerConfig

UniformInitializer configuration class.

__class_ref__: str = 'UniformInitializer'

class spark.nn.initializers.SparseUniformInitializer(*, config=None, **kwargs)

Bases: UniformInitializer

Initializer that returns real, sparse, uniformly distributed random arrays.

Note that the output will contain zero values even if min_value > 0.

Init:

scale: numeric, multiplicative factor for the output array (default = 1).

min_value: numeric, minimum value for the output array (default = None).

max_value: numeric, maximum value for the output array (default = None).

density: float, expected ratio of non-zero entries (default = 0.2).

Input:

key: jax.Array, key for the random generator (jax.random.key).

shape: tuple[int, …], shape for the output array.

Parameters:

config (ConfigT | None)

config: SparseUniformInitializerConfig
__call__(key, shape)
Parameters:

key (jax.Array)

shape (tuple[int, …])

Return type:

jax.Array
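
One plausible way to obtain the sparsity described for this class is to mask a dense uniform sample with a Bernoulli mask, which also shows why zeros appear even when min_value > 0. The helper sparse_uniform_sketch is hypothetical and is not Spark's implementation; the masking strategy and the key split are assumptions.

```python
import jax
import jax.numpy as jnp


def sparse_uniform_sketch(key, shape, density=0.2, min_value=0.5, max_value=1.0):
    """Hypothetical helper: uniform values kept with probability `density`."""
    # Use independent keys for the values and for the sparsity mask.
    value_key, mask_key = jax.random.split(key)
    values = jax.random.uniform(value_key, shape, minval=min_value, maxval=max_value)
    mask = jax.random.bernoulli(mask_key, p=density, shape=shape)
    # Masked-out entries become exactly zero, regardless of min_value.
    return jnp.where(mask, values, 0.0)


out = sparse_uniform_sketch(jax.random.key(0), (100, 100))
nonzero_fraction = float(jnp.mean(out != 0))
```

Here density is only the expected ratio of non-zero entries, so nonzero_fraction is close to 0.2 but not exactly equal to it.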

class spark.nn.initializers.SparseUniformInitializerConfig(**kwargs)

Bases: UniformInitializerConfig

SparseUniformInitializer configuration class.

__class_ref__: str = 'SparseUniformInitializer'
density: float

class spark.nn.initializers.NormalizedSparseUniformInitializer(*, config=None, **kwargs)

Bases: SparseUniformInitializer

Initializer that returns real, sparse, uniformly distributed random arrays. This is a variation of SparseUniformInitializer that normalizes the array, which may be useful to prevent quiescent neurons. Entries are normalized by contracting the array onto norm_axes (summing over the remaining axes) and are then rescaled back to [min_value, max_value].

Normalization example: for an array with indices ijk and norm_axes = (i, k), the contraction is 'ijk->ik', so that sum(norm_array[i, :, k]) = 1 for every i and k.

Note that the output will contain zero values even if min_value > 0.

Init:

scale: numeric, multiplicative factor for the output array (default = 1).

min_value: numeric, minimum value for the output array (default = None).

max_value: numeric, maximum value for the output array (default = None).

density: float, expected ratio of non-zero entries (default = 0.2).

norm_axes: tuple[int, …], axes used for normalization (default = (0,)).

Input:

key: jax.Array, key for the random generator (jax.random.key). shape: tuple[int, …], shape for the output array.

Output:

jax.Array[dtype]

Parameters:

config (ConfigT | None)

config: NormalizedSparseUniformInitializerConfig
__call__(key, shape)
Parameters:

key (jax.Array)

shape (tuple[int, …])

Return type:

jax.Array
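
The normalization step from the example above (indices ijk, norm_axes = (i, k), each slice [i, :, k] summing to 1) can be sketched as follows. The helper normalize_over_axes_sketch is hypothetical; it covers only the normalization, not the sparsity mask or the final rescaling to [min_value, max_value], and the guard against all-zero slices is an added assumption.

```python
import jax
import jax.numpy as jnp


def normalize_over_axes_sketch(array, norm_axes):
    """Hypothetical helper: make sums over the non-norm axes equal to 1."""
    # Axes that are NOT in norm_axes are the ones summed out ('ijk->ik' sums j).
    reduce_axes = tuple(ax for ax in range(array.ndim) if ax not in norm_axes)
    totals = jnp.sum(array, axis=reduce_axes, keepdims=True)
    # Assumption: leave all-zero slices untouched instead of dividing by zero.
    safe_totals = jnp.where(totals == 0, 1.0, totals)
    return array / safe_totals


array = jax.random.uniform(jax.random.key(0), (2, 3, 4), minval=0.1, maxval=1.0)
normalized = normalize_over_axes_sketch(array, norm_axes=(0, 2))
slice_sums = normalized.sum(axis=1)  # one sum per (i, k) pair
```

With norm_axes=(0, 2), slice_sums has shape (2, 4) and every entry is 1 up to floating-point error.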

class spark.nn.initializers.NormalizedSparseUniformInitializerConfig(**kwargs)

Bases: SparseUniformInitializerConfig

NormalizedSparseUniformInitializer configuration class.

__class_ref__: str = 'NormalizedSparseUniformInitializer'
norm_axes: tuple[int, ...] | None