Tutorial #2: Execution


Spark is built on top of JAX, and one of the core features of JAX is Just-In-Time compilation, or JIT for short. JIT is extremely good at optimizing how your code executes, so it is important to have a basic understanding of how to use it. Let's start with a simple JIT example.

import jax
import jax.numpy as jnp

# Some really cool function
def my_awesome_function(x, y):
    return x * y + x**2

# Two random arrays
rng = jax.random.key(42)
key_x, key_y = jax.random.split(rng, 2)
x = jax.random.uniform(key_x, shape=(4096,4096), dtype=jnp.float32)
y = jax.random.uniform(key_y, shape=(4096,4096), dtype=jnp.float32)

# Test it!
%timeit my_awesome_function(x, y)
1.41 ms ± 308 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Now, this is fast, but can it be better?

Let's JIT the function and run it again.

# Some really cool and fast function
@jax.jit
def my_awesome_function(x, y):
    return x * y + x**2

%timeit my_awesome_function(x, y)
528 μs ± 302 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
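One caveat when reading timings like these: JAX dispatches work asynchronously, so a call can return before the result has actually been computed, and a timing loop may partly measure dispatch rather than computation. Below is a minimal sketch of a benchmark that waits for every result with `block_until_ready()` and keeps compilation out of the measurement. The helper `bench` is our own illustrative name (not part of JAX or Spark), and the arrays are smaller than above just to keep the sketch quick.

```python
import time

import jax
import jax.numpy as jnp


def my_awesome_function(x, y):
    return x * y + x**2


def bench(fn, *args, repeats=20):
    """Hypothetical helper: mean wall-clock seconds per call."""
    fn(*args).block_until_ready()  # Warm-up call; for jitted fns this compiles.
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()  # Wait until the result is really computed.
    return (time.perf_counter() - start) / repeats


key_x, key_y = jax.random.split(jax.random.key(42), 2)
x = jax.random.uniform(key_x, shape=(1024, 1024), dtype=jnp.float32)
y = jax.random.uniform(key_y, shape=(1024, 1024), dtype=jnp.float32)

eager_t = bench(my_awesome_function, x, y)
jitted_t = bench(jax.jit(my_awesome_function), x, y)
print(f"eager: {eager_t * 1e6:.1f} us, jitted: {jitted_t * 1e6:.1f} us")
```

In a notebook the same idea applies to `%timeit`, e.g. `%timeit my_awesome_function(x, y).block_until_ready()`.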

Note that JIT, as its name suggests, compiles the code just before running it, which makes the first call to a function noticeably slower than the following ones. Depending on the code being JIT-compiled, compilation can take a really long time; for example, it is possible to compile an entire for loop that executes thousands of operations in total (this is not recommended in general, but it may be a useful trick depending on the circumstances).
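The compilation cost can be observed by timing the first call separately from a later one. The sketch below also contrasts a Python `for` loop, which JIT unrolls at trace time into one long chain of operations, with `jax.lax.fori_loop`, which keeps the body as a single loop in the compiled program. The function names and `N_STEPS` are illustrative choices, and the exact timings will depend on your hardware.

```python
import time

import jax
import jax.numpy as jnp

N_STEPS = 200  # Illustrative loop length.


@jax.jit
def unrolled(x):
    # The Python loop runs during tracing, so the compiled program
    # contains N_STEPS copies of the loop body.
    for _ in range(N_STEPS):
        x = jnp.sin(x) + 0.1
    return x


@jax.jit
def rolled(x):
    # fori_loop compiles the body once, as a single loop.
    return jax.lax.fori_loop(0, N_STEPS, lambda i, v: jnp.sin(v) + 0.1, x)


def timed_call(fn, x):
    start = time.perf_counter()
    out = fn(x).block_until_ready()
    return out, time.perf_counter() - start


x = jnp.linspace(0.0, 1.0, 1000)
for name, fn in [("unrolled", unrolled), ("fori_loop", rolled)]:
    _, first = timed_call(fn, x)   # Includes tracing and compilation.
    _, second = timed_call(fn, x)  # Reuses the cached executable.
    print(f"{name}: first call {first * 1e3:.1f} ms, second call {second * 1e3:.3f} ms")
```

Both versions compute the same values; the difference is mostly in how long the first call takes to compile.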

Unfortunately, it is slightly more complicated in Spark (but do not worry, it is not that complicated!).

Under the hood, JAX traces the computation graph and tries to optimize it; however, this is not trivial when your code has to manage state. In Spark we borrow one of the core elements of Flax, an excellent machine learning library, which lets us not worry too much about state management at the expense of a slightly more involved JIT execution. Although this may change as Spark matures, right now this is the standard way to execute your code.
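Why is state tricky? `jax.jit` runs your Python function once with abstract tracer values to record the computation graph, and later calls with the same input shapes and dtypes replay the compiled graph. Python side effects, such as appending to a list or mutating an object, therefore only happen during tracing. A small sketch in plain JAX (no Spark involved; all names are illustrative), ending with `jax.make_jaxpr`, which prints the graph JAX recorded:

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python-side state, invisible to the compiled program.


@jax.jit
def impure_step(x):
    trace_log.append("traced")  # Runs only while JAX traces the function.
    return x * 2 + 1


impure_step(jnp.arange(3.0))
impure_step(jnp.arange(3.0) + 1.0)  # Same shape/dtype: cached, no retrace.
print(len(trace_log))  # → 1


@jax.jit
def pure_step(count, x):
    # Pure alternative: state goes in as an argument and comes back out.
    return count + 1, x * 2 + 1


count = jnp.array(0)
for _ in range(3):
    count, _ = pure_step(count, jnp.arange(3.0))
print(count)  # → 3

# Inspect the computation graph that tracing produces.
print(jax.make_jaxpr(lambda v: v * 2 + 1)(jnp.arange(3.0)))
```

Passing state in and returning it out, as `pure_step` does, is the same idea that the Spark execution pattern in this tutorial builds on.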

Let's start by initializing a simple ALIF (adaptive leaky integrate-and-fire) model.

# Imports
import spark

# Number of neurons
units = (8,)

# Input shape
input_shape = (16,)

# Model initialization
alif_neurons = spark.nn.neurons.ALIFNeuron(
    units=units,
    max_delay=4,
    inhibitory_rate=0.2,
    _s_async_spikes=True,
    _s_units=units,
)

# Test the model.
rng = jax.random.key(42)
rng, key = jax.random.split(rng, 2)
in_spikes = spark.SpikeArray( 
    jax.random.uniform(key, shape=input_shape, dtype=jnp.float16) < 0.5 
)

# Note that part of the model is only created the first time it is called.
try:
    alif_neurons.soma
    print('Those are some nice somas!.\n')
except AttributeError:
    print('Oh no!, alif_neurons do not have a soma property.\n')

# First call to alif_neurons
out_spikes = alif_neurons(in_spikes=in_spikes)
print(out_spikes, '\n')

# Now the soma property exists.
try:
    alif_neurons.soma
    print('Those are some nice somas!.')
except AttributeError:
    print('Oh no!, alif_neurons do not have a soma property.')
Oh no!, alif_neurons do not have a soma property.

{'out_spikes': SpikeArray(value=Array([ 0.,  0., -0.,  0.,  0.,  0.,  0.,  0.], dtype=float16))} 

Those are some nice somas!.

Similarly to the cases above, we just need to wrap the execution of the model inside a function that we can then compile. Remember that the first call to this function takes longer, because it includes compilation (important if you want to benchmark your own models!).

# JITted function for execution
@jax.jit
def run_model(model, **inputs):
    outputs = model(**inputs)
    return outputs, model  # <-- Return and replace the model. Otherwise, the model will not change.

# Execution
outputs, alif_neurons = run_model(alif_neurons, in_spikes=in_spikes) # Note that we overwrite the model after the execution.

%timeit run_model(alif_neurons, in_spikes=in_spikes)
126 μs ± 7.46 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
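To see why `run_model` has to return the model, here is a sketch with a hypothetical toy class, `ToyNeuron` (not part of Spark), registered as a JAX pytree. Spark models are built on Flax NNX, so their internals differ, but the principle shown here is the same: inside `jax.jit` the function works on a traced copy of the model, so an in-place state update only survives if the updated model is returned and the caller rebinds its variable.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyNeuron:
    """Hypothetical leaky integrator with an in-place state update."""

    def __init__(self, potential, decay):
        self.potential = potential  # Dynamic state: a pytree leaf.
        self.decay = decay          # Static configuration: auxiliary data.

    def tree_flatten(self):
        return (self.potential,), (self.decay,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], *aux_data)

    def __call__(self, in_spikes):
        # Mutates the object in place, like a stateful neuron model.
        self.potential = self.decay * self.potential + in_spikes
        return {"out_spikes": self.potential > 1.0}


@jax.jit
def run_model(model, **inputs):
    outputs = model(**inputs)
    return outputs, model  # Return the updated copy so the caller can keep it.


model = ToyNeuron(jnp.zeros(4), 0.5)
spikes = jnp.ones(4)

outputs, new_model = run_model(model, in_spikes=spikes)
print(model.potential)      # Still zeros: the update happened on a traced copy.
print(new_model.potential)  # The updated state returned by run_model.

# Feeding the returned model back in continues from the updated state.
outputs, new_model = run_model(new_model, in_spikes=spikes)
print(outputs["out_spikes"])
```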

Although we can pass our model directly to JAX and it will execute correctly and quickly, it is almost always faster to split the model into a graph and a state (a quirk of the code ¯\_(ツ)_/¯).

To separate our model into a graph and a state we can use spark.split, and to glue them back together for execution we use spark.merge. Note that these are currently just convenience wrappers around Flax's nnx.split and nnx.merge, respectively, although this may change in the future. That's it, that is as much JIT as you need to know. Just remember: when you define your own models, use spark.Constant and spark.Variable to wrap the data held by your model, otherwise JIT is going to complain!

As a small observation, note that we pass everything to the model with **inputs, since the __call__ method expects its inputs as keyword arguments (although they can still be passed by position). This is particularly useful for models with more than one input stream, or when we want to try many different models, since it lets us write code that is more reusable and flexible.
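The benefit of keyword inputs is easiest to see with one generic runner serving models that have different input streams. In this minimal sketch, two plain functions stand in for models (`single_stream_model`, `two_stream_model` and `run_any` are hypothetical names), and the model is marked static with `static_argnums` so that `jax.jit` compiles one version per model:

```python
from functools import partial

import jax
import jax.numpy as jnp


def single_stream_model(in_spikes):
    # Hypothetical model with one input stream.
    return {"out": in_spikes * 2.0}


def two_stream_model(in_spikes, in_currents):
    # Hypothetical model with two input streams.
    return {"out": in_spikes + in_currents}


@partial(jax.jit, static_argnums=0)
def run_any(model_fn, **inputs):
    # The runner never needs to know the input names; it just forwards them.
    return model_fn(**inputs)


spikes = jnp.array([0.0, 1.0, 1.0])
currents = jnp.array([0.5, 0.5, 0.5])

print(run_any(single_stream_model, in_spikes=spikes)["out"])
print(run_any(two_stream_model, in_spikes=spikes, in_currents=currents)["out"])
```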

# JITted function for execution
@jax.jit
def run_model_split(graph, state, **inputs):
    model = spark.merge(graph, state)  # Merge the graph and the state into a single executable model equivalent to ALIFNeuron.
    outputs = model(**inputs)
    _, state = spark.split((model))  # Split the model back into a graph and a state. Here we discard the graph since typically it doesn't change.
    return outputs, state

# Execution
# Note the parentheses around the model, this will be important for more complex cases.
# (On their own they are a no-op: (alif_neurons) is just alif_neurons; a tuple needs a trailing comma.)
graph, state = spark.split((alif_neurons))
outputs, state = run_model_split(graph, state, in_spikes=in_spikes)  # Note that we overwrite the state after the execution.

%timeit run_model_split(graph, state, in_spikes=in_spikes)
80.8 μs ± 3.18 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

This is basically all the JAX & JIT you need to know to get the most out of Spark!