spark.core.flax_imports

Attributes

A

Functions

data(value, /)

grad([f, argnums, has_aux, holomorphic, allow_int, ...])

Wrapper around flax.nnx.grad to simplify imports.

jit([fun, in_shardings, out_shardings, ...])

Wrapper around flax.nnx.jit to simplify imports.

eval_shape(f, *args, **kwargs)

Wrapper around flax.nnx.eval_shape to simplify imports.

split(node, *filters)

Wrapper around flax.nnx.split to simplify imports.

merge(graphdef, state, /, *states)

Wrapper around flax.nnx.merge to simplify imports.

Module Contents

spark.core.flax_imports.A
spark.core.flax_imports.data(value, /)
Parameters:

value (A)

Return type:

A

spark.core.flax_imports.grad(f=MISSING, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())

Wrapper around flax.nnx.grad to simplify imports.

Parameters:
  • f (Callable[Ellipsis, Any] | flax.typing.Missing)

  • argnums (int | flax.nnx.transforms.autodiff.DiffState | Sequence[int | flax.nnx.transforms.autodiff.DiffState])

  • has_aux (bool)

  • holomorphic (bool)

  • allow_int (bool)

  • reduce_axes (Sequence[flax.nnx.transforms.autodiff.AxisName])

Return type:

Callable[Ellipsis, Any] | Callable[[Callable[Ellipsis, Any]], Callable[Ellipsis, Any]]

spark.core.flax_imports.jit(fun=Missing, *, in_shardings=None, out_shardings=None, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)

Wrapper around flax.nnx.jit to simplify imports.

Parameters:
  • fun (Callable[Ellipsis, Any] | type[flax.typing.Missing])

  • in_shardings (Any)

  • out_shardings (Any)

  • static_argnums (int | Sequence[int] | None)

  • static_argnames (str | Iterable[str] | None)

  • donate_argnums (int | Sequence[int] | None)

  • donate_argnames (str | Iterable[str] | None)

  • keep_unused (bool)

  • device (jax.Device | None)

  • backend (str | None)

  • inline (bool)

  • abstracted_axes (Any | None)

Return type:

jax._src.pjit.JitWrapped | Callable[[Callable[Ellipsis, Any]], jax._src.pjit.JitWrapped]

spark.core.flax_imports.eval_shape(f, *args, **kwargs)

Wrapper around flax.nnx.eval_shape to simplify imports.

Parameters:
  • f (Callable[Ellipsis, A])

  • args (Any)

  • kwargs (Any)

Return type:

A

spark.core.flax_imports.split(node, *filters)

Wrapper around flax.nnx.split to simplify imports.

Parameters:
  • node (A)

  • filters (flax.nnx.filterlib.Filter)

Return type:

tuple[flax.nnx.graph.GraphDef[A], flax.nnx.graph.GraphState | flax.nnx.variablelib.VariableState, typing_extensions.Unpack[tuple[flax.nnx.graph.GraphState | flax.nnx.variablelib.VariableState, Ellipsis]]]

spark.core.flax_imports.merge(graphdef, state, /, *states)

Wrapper around flax.nnx.merge to simplify imports.

Parameters:
  • graphdef

  • state

  • states

Return type:

A