spark.core.variables#

Classes#

Constant

Wrapper around jax.Array for constant arrays.

Variable

The base class for all Variable types.

Module Contents#

class spark.core.variables.Constant(data, dtype=None)[source]#

Wrapper around jax.Array for constant arrays.

Parameters:
  • data (Any)

  • dtype (Any)

value: jax.Array[source]#
__jax_array__()[source]#
Return type:

jax.Array

__array__(dtype=None)[source]#
Return type:

numpy.ndarray

property shape: tuple[int, ...][source]#
Return type:

tuple[int, ...]

property dtype: Any[source]#
Return type:

Any

property ndim: int[source]#
Return type:

int

property size: int[source]#
Return type:

int

property T: jax.Array[source]#
Return type:

jax.Array

__neg__()[source]#
Return type:

jax.Array

__pos__()[source]#
Return type:

jax.Array

__abs__()[source]#
Return type:

jax.Array

__invert__()[source]#
Return type:

jax.Array

__add__(other)[source]#
Return type:

jax.Array

__sub__(other)[source]#
Return type:

jax.Array

__mul__(other)[source]#
Return type:

jax.Array

__truediv__(other)[source]#
Return type:

jax.Array

__floordiv__(other)[source]#
Return type:

jax.Array

__mod__(other)[source]#
Return type:

jax.Array

__matmul__(other)[source]#
Return type:

jax.Array

__pow__(other)[source]#
Return type:

jax.Array

__radd__(other)[source]#
Return type:

jax.Array

__rsub__(other)[source]#
Return type:

jax.Array

__rmul__(other)[source]#
Return type:

jax.Array

__rtruediv__(other)[source]#
Return type:

jax.Array

__rfloordiv__(other)[source]#
Return type:

jax.Array

__rmod__(other)[source]#
Return type:

jax.Array

__rmatmul__(other)[source]#
Return type:

jax.Array

__rpow__(other)[source]#
Return type:

jax.Array

class spark.core.variables.Variable(value, dtype=None, **metadata)[source]#

Bases: flax.nnx.Variable

The base class for all Variable types. Note that this is just a convenience wrapper around Flax’s nnx.Variable to simplify imports.

Parameters:
  • value (Any)

  • dtype (Any)

value: jax.Array[source]#
__jax_array__()[source]#
Return type:

jax.Array

__array__(dtype=None)[source]#
Return type:

numpy.ndarray

property shape: tuple[int, ...][source]#
Return type:

tuple[int, ...]