spark.core.flax_imports
Attributes
Functions
| Function | Description |
|---|---|
| grad | Wrapper around flax.nnx.grad to simplify imports. |
| jit | Wrapper around flax.nnx.jit to simplify imports. |
| eval_shape | Wrapper around flax.nnx.eval_shape to simplify imports. |
| split | Wrapper around flax.nnx.split to simplify imports. |
| merge | Wrapper around flax.nnx.merge to simplify imports. |
Module Contents
- spark.core.flax_imports.grad(f=MISSING, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())
- Wrapper around flax.nnx.grad to simplify imports.
- Return type:
- Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]
 
- spark.core.flax_imports.jit(fun=Missing, *, in_shardings=None, out_shardings=None, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)
- Wrapper around flax.nnx.jit to simplify imports.
- Parameters:
- fun (Callable[..., Any] | type[flax.typing.Missing])
- in_shardings (Any) 
- out_shardings (Any) 
- keep_unused (bool) 
- device (jax.Device | None) 
- backend (str | None) 
- inline (bool) 
- abstracted_axes (Any | None) 
 
- Return type:
- jax._src.pjit.JitWrapped | Callable[[Callable[..., Any]], jax._src.pjit.JitWrapped]
 
- spark.core.flax_imports.eval_shape(f, *args, **kwargs)
- Wrapper around flax.nnx.eval_shape to simplify imports.
- Parameters:
- f (Callable[..., A])
- args (Any) 
- kwargs (Any) 
 
- Return type:
- A 
 
- spark.core.flax_imports.split(node, *filters)
- Wrapper around flax.nnx.split to simplify imports.
- Parameters:
- node (A) 
- filters (flax.nnx.filterlib.Filter) 
 
- Return type:
- tuple[flax.nnx.graph.GraphDef[A], flax.nnx.graph.GraphState | flax.nnx.variablelib.VariableState, typing_extensions.Unpack[tuple[flax.nnx.graph.GraphState | flax.nnx.variablelib.VariableState, ...]]]
 
- spark.core.flax_imports.merge(graphdef, state, /, *states)
- Wrapper around flax.nnx.merge to simplify imports.
- Parameters:
- graphdef (flax.nnx.graph.GraphDef[A]) 
- state (Any) 
- states (Any) 
 
- Return type:
- A 
 
