spark.core.flax_imports
Attributes
Functions
| Function | Description |
|---|---|
| grad | Wrapper around flax.nnx.grad to simplify imports. |
| jit | Wrapper around flax.nnx.jit to simplify imports. |
| eval_shape | Wrapper around flax.nnx.eval_shape to simplify imports. |
| split | Wrapper around flax.nnx.split to simplify imports. |
| merge | Wrapper around flax.nnx.merge to simplify imports. |
Module Contents
- spark.core.flax_imports.grad(f=MISSING, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())
Wrapper around flax.nnx.grad to simplify imports.
- Parameters:
- Return type:
Callable[Ellipsis, Any] | Callable[[Callable[Ellipsis, Any]], Callable[Ellipsis, Any]]
- spark.core.flax_imports.jit(fun=Missing, *, in_shardings=None, out_shardings=None, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)
Wrapper around flax.nnx.jit to simplify imports.
- Parameters:
fun (Callable[Ellipsis, Any] | type[flax.typing.Missing])
in_shardings (Any)
out_shardings (Any)
keep_unused (bool)
device (jax.Device | None)
backend (str | None)
inline (bool)
abstracted_axes (Any | None)
- Return type:
jax._src.pjit.JitWrapped | Callable[[Callable[Ellipsis, Any]], jax._src.pjit.JitWrapped]
- spark.core.flax_imports.eval_shape(f, *args, **kwargs)
Wrapper around flax.nnx.eval_shape to simplify imports.
- Parameters:
f (Callable[Ellipsis, A])
args (Any)
kwargs (Any)
- Return type:
A
- spark.core.flax_imports.split(node, *filters)
Wrapper around flax.nnx.split to simplify imports.
- Parameters:
node (A)
filters (flax.nnx.filterlib.Filter)
- Return type:
tuple[flax.nnx.graph.GraphDef[A], flax.nnx.graph.GraphState | flax.nnx.variablelib.VariableState, typing_extensions.Unpack[tuple[flax.nnx.graph.GraphState | flax.nnx.variablelib.VariableState, Ellipsis]]]
- spark.core.flax_imports.merge(graphdef, state, /, *states)
Wrapper around flax.nnx.merge to simplify imports.
- Parameters:
graphdef (flax.nnx.graph.GraphDef[A])
state (Any)
states (Any)
- Return type:
A