spark.nn.initializers
Submodules
Classes
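| Class | Description |
|---|---|
| Initializer | Base (abstract) class for all Spark initializers. |
| InitializerConfig | Base configuration class for initializers. |
| ConstantInitializer | Initializer that returns real arrays filled with a constant value. |
| ConstantInitializerConfig | ConstantInitializer configuration class. |
| UniformInitializer | Initializer that returns real uniformly-distributed random arrays. |
| UniformInitializerConfig | UniformInitializer configuration class. |
| SparseUniformInitializer | Initializer that returns real sparse uniformly-distributed random arrays. |
| SparseUniformInitializerConfig | SparseUniformInitializer configuration class. |
| NormalizedSparseUniformInitializer | Initializer that returns real sparse uniformly-distributed random arrays, normalized along norm_axes. |
| NormalizedSparseUniformInitializerConfig | NormalizedSparseUniformInitializer configuration class. |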
Package Contents
- class spark.nn.initializers.Initializer(*, config=None, **kwargs)
Bases: abc.ABC
Base (abstract) class for all Spark initializers.
- Parameters:
config (ConfigT | None)
- config: InitializerConfig
- class spark.nn.initializers.InitializerConfig(**kwargs)
Bases: spark.core.config.BaseSparkConfig
Base configuration class for initializers.
- class spark.nn.initializers.ConstantInitializer(*, config=None, **kwargs)
Bases: spark.nn.initializers.base.Initializer
Initializer that returns real arrays filled with a constant value.
- Init:
scale: numeric, value of every entry of the output array (default = 1).
- Input:
key: jax.Array, key for the random generator (jax.random.key).
shape: tuple[int, …], shape of the output array.
- Parameters:
config (ConfigT | None)
- class spark.nn.initializers.ConstantInitializerConfig(**kwargs)
Bases: spark.nn.initializers.base.InitializerConfig
ConstantInitializer configuration class.
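The classes above follow a common pattern: an abstract base class whose instances are called with a random key and a shape, plus a configuration object holding the Init parameters. The sketch below mirrors that pattern in plain JAX for the constant case. All names in it (`SketchInitializer`, `SketchConstantConfig`, `SketchConstantInitializer`) are hypothetical, and building the config from keyword arguments when none is passed is an assumption; it is not Spark's implementation.

```python
import abc
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class SketchConstantConfig:
    # Hypothetical config: holds the "Init" parameter of a constant initializer.
    scale: float = 1.0


class SketchInitializer(abc.ABC):
    """Hypothetical abstract initializer mapping (key, shape) to an array."""

    config_cls = None  # each concrete subclass names its config class

    def __init__(self, *, config=None, **kwargs):
        # Assumption: without an explicit config, build one from keyword arguments.
        self.config = config if config is not None else self.config_cls(**kwargs)

    @abc.abstractmethod
    def __call__(self, key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
        """Return an array of the given shape."""


class SketchConstantInitializer(SketchInitializer):
    config_cls = SketchConstantConfig

    def __call__(self, key, shape):
        # The key is accepted for interface compatibility but unused:
        # every entry is set to the configured constant.
        return jnp.full(shape, self.config.scale)


init = SketchConstantInitializer(scale=0.5)
weights = init(jax.random.key(0), (2, 3))
```

Keeping the key in the signature of a deterministic initializer lets all initializers be called interchangeably.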
- class spark.nn.initializers.UniformInitializer(*, config=None, **kwargs)
Bases: spark.nn.initializers.base.Initializer
Initializer that returns real uniformly-distributed random arrays.
- Init:
scale: numeric, multiplicative factor for the output array (default = 1).
min_value: numeric, minimum value for the output array (default = None).
max_value: numeric, maximum value for the output array (default = None).
- Input:
key: jax.Array, key for the random generator (jax.random.key).
shape: tuple[int, …], shape of the output array.
- Parameters:
config (ConfigT | None)
- class spark.nn.initializers.UniformInitializerConfig(**kwargs)
Bases: spark.nn.initializers.base.InitializerConfig
UniformInitializer configuration class.
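A minimal JAX sketch of the uniform case follows. The docstring does not say how `scale`, `min_value`, and `max_value` interact, so the sketch assumes that samples are drawn uniformly from [min_value, max_value), with missing bounds defaulting to 0 and 1, and are then multiplied by `scale`. The function name `uniform_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def uniform_sketch(key, shape, scale=1.0, min_value=None, max_value=None,
                   dtype=jnp.float32):
    # Assumption: unspecified bounds fall back to the unit interval [0, 1).
    lo = 0.0 if min_value is None else min_value
    hi = 1.0 if max_value is None else max_value
    # Draw i.i.d. samples from U[lo, hi), then apply the multiplicative factor.
    samples = jax.random.uniform(key, shape, dtype=dtype, minval=lo, maxval=hi)
    return scale * samples


w = uniform_sketch(jax.random.key(0), (4, 5), scale=2.0, min_value=0.1, max_value=0.3)
```

Under these assumptions every entry of `w` lies in [0.2, 0.6).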
- class spark.nn.initializers.SparseUniformInitializer(*, config=None, **kwargs)
Bases: UniformInitializer
Initializer that returns real sparse uniformly-distributed random arrays.
Note that the output will contain zero values even if min_value > 0.
- Init:
scale: numeric, multiplicative factor for the output array (default = 1).
min_value: numeric, minimum value for the output array (default = None).
max_value: numeric, maximum value for the output array (default = None).
density: float, expected ratio of non-zero entries (default = 0.2).
- Input:
key: jax.Array, key for the random generator (jax.random.key).
shape: tuple[int, …], shape of the output array.
- Parameters:
config (ConfigT | None)
- class spark.nn.initializers.SparseUniformInitializerConfig(**kwargs)
Bases: UniformInitializerConfig
SparseUniformInitializer configuration class.
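The note about zeros follows from a common way of building a sparse random array: draw a dense uniform array and zero out entries with a Bernoulli mask whose success probability is `density`. The sketch below assumes that construction (it is not necessarily how Spark implements it) and reuses the bound handling assumed for the uniform case; `sparse_uniform_sketch` is a hypothetical name. Because masked entries are exactly 0, zeros appear even when `min_value > 0`, and the non-zero fraction equals `density` only in expectation.

```python
import jax
import jax.numpy as jnp


def sparse_uniform_sketch(key, shape, scale=1.0, min_value=None, max_value=None,
                          density=0.2, dtype=jnp.float32):
    # Use independent keys for the values and for the sparsity mask.
    value_key, mask_key = jax.random.split(key)
    lo = 0.0 if min_value is None else min_value  # assumed default bounds
    hi = 1.0 if max_value is None else max_value
    values = scale * jax.random.uniform(value_key, shape, dtype=dtype,
                                        minval=lo, maxval=hi)
    # Keep each entry independently with probability `density`.
    mask = jax.random.bernoulli(mask_key, p=density, shape=shape)
    return jnp.where(mask, values, jnp.zeros((), dtype=dtype))


w = sparse_uniform_sketch(jax.random.key(0), (200, 200), min_value=0.5, max_value=1.0)
```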
- class spark.nn.initializers.NormalizedSparseUniformInitializer(*, config=None, **kwargs)
Bases: SparseUniformInitializer
Initializer that returns real sparse uniformly-distributed random arrays. This variation of SparseUniformInitializer normalizes the array, which may be useful to prevent quiescent neurons. Entries are normalized by contracting the array onto the norm_axes (summing over all other axes) and are then rescaled back to [min_value, max_value].
Normalization example: for an array with axes ijk and norm_axes = (0, 2), i.e. axes i and k, the contraction is 'ijk->ik', so that sum(norm_array[i, :, k]) = 1 for every i and k.
Note that the output will contain zero values even if min_value > 0.
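The contraction in the normalization example can be written directly with `jnp.einsum`: summing over the axis not listed in norm_axes gives one total per (i, k) pair, and dividing by that total (broadcast back over j) makes every fibre norm_array[i, :, k] sum to 1. This sketch covers only the normalization step, not the subsequent rescaling to [min_value, max_value], whose exact rule the docstring does not specify; the guard against all-zero fibres is an added assumption.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
array = jax.random.uniform(key, (3, 4, 5))  # axes i, j, k

# norm_axes = (0, 2), i.e. axes i and k: contract over j ('ijk->ik').
totals = jnp.einsum('ijk->ik', array)

# Broadcast the per-(i, k) totals back over j; guard against all-zero fibres
# (possible for sparse arrays) to avoid dividing by zero.
safe_totals = jnp.where(totals == 0, 1.0, totals)
norm_array = array / safe_totals[:, None, :]
```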
- Init:
scale: numeric, multiplicative factor for the output array (default = 1).
min_value: numeric, minimum value for the output array (default = None).
max_value: numeric, maximum value for the output array (default = None).
density: float, expected ratio of non-zero entries (default = 0.2).
norm_axes: tuple[int, …], axes used for normalization (default = (0,)).
- Input:
key: jax.Array, key for the random generator (jax.random.key).
shape: tuple[int, …], shape of the output array.
- Output:
jax.Array[dtype]
- Parameters:
config (ConfigT | None)