Tutorial #1: Modules and Components


Spark is intended as a research framework. As such, one of the first things you may want to do is add a new element to the system.

Fortunately, adding new components is quite straightforward, although it may be a little confusing the first time. Let's start from the beginning. In Spark, as in most other major libraries, every module inherits from spark.nn.Module. This class is at the heart of most of Spark's magic. Depending on the kind of module you are building, it may make sense to subclass one of the following classes instead:

Name                                   Purpose
spark.nn.Component                     Arbitrary component of a neuron
spark.nn.somas.Soma                    Soma models
spark.nn.synapses.Synapses             Synaptic models
spark.nn.learning_rules.LearningRule   Learning rules
spark.nn.delays.Delays                 Spike delay mechanisms

This list is not exhaustive, but it contains the most important subclasses of spark.nn.Module.

REMARK: One important thing to mention before we start is that Python does not strictly enforce typing; however, Spark relies on type hints for a few of its core features. Although some of your code may work without them, a few things require type hints if you want your components to play nicely with the Spark ecosystem. In general, this is straightforward and only required in a few places, so if you really despise typing, most of your code can remain untyped.

For the time being, let's just create a simple module.

import sys
sys.path.insert(1, './..')

# Imports
import spark
import typing as tp
import jax.numpy as jnp

The first thing to notice is that new components are defined in Module - Configuration pairs: the Module defines the logic, and the Configuration defines all the parameters required to initialize that Module. By convention, we simply name these pairs Module and ModuleConfig.

Every time a new module is defined, we need to link it to its default configuration. This is done by adding the class-level annotation config: ConfigClass to the Module class definition.

Another thing to notice at this point is that all modules share the same __init__ signature: it accepts only a configuration object (optional, defaulting to None) and keyword arguments. Similarly, the first thing any __init__ method must do is invoke the __init__ method of the superclass, following the pattern shown below.

class MyAwesomeModuleConfig(spark.nn.Config):
    pass

class MyAwesomeModule(spark.nn.Module):
    config: MyAwesomeModuleConfig       # <--- Default configuration class MUST always be indicated

    def __init__(self, config: MyAwesomeModuleConfig = None, **kwargs):
        super().__init__(config=config, **kwargs)       # <--- super().__init__ should always be invoked first as follows

At this point you may be wondering where all your important __init__ arguments go. The answer is the configuration class.

Spark modules separate model definition from model logic. This allows us to do some neat tricks behind the scenes, and it is extremely useful for reproducibility. Configuration classes should always be typed; this is done with the notation variable_name: variable_type = default_value.

All the variables defined inside the configuration class are available in the __init__ method, after calling super().__init__, under the namespace self.config.
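The sketch below shows, without Spark, one way a base class could use the config annotation to find a module's default configuration class and expose the resolved configuration as self.config. BaseConfig, BaseModule, ScalerModule and ScalerModuleConfig are hypothetical names, plain dataclasses stand in for spark.nn.Config, and this is an illustration of the pattern rather than Spark's actual implementation.

```python
import dataclasses as dc
import typing as tp


@dc.dataclass
class BaseConfig:
    """Hypothetical stand-in for spark.nn.Config."""


class BaseModule:
    """Hypothetical stand-in for spark.nn.Module (not Spark's real code)."""

    config: BaseConfig

    def __init__(self, config: tp.Optional[BaseConfig] = None, **kwargs):
        # Look up the default configuration class from the `config`
        # annotation; the subclass annotation overrides the base one.
        default_cls = tp.get_type_hints(type(self))['config']
        # Without an explicit config, build the default one from the kwargs.
        self.config = config if config is not None else default_cls(**kwargs)


@dc.dataclass
class ScalerModuleConfig(BaseConfig):
    foo: int              # required: no default value
    bar: float = 2.0      # variable_name: variable_type = default_value


class ScalerModule(BaseModule):
    config: ScalerModuleConfig  # links the module to its default configuration

    def __init__(self, config: tp.Optional[ScalerModuleConfig] = None, **kwargs):
        super().__init__(config=config, **kwargs)
        # self.config only exists after super().__init__ has run.
        self.scale = self.config.foo * self.config.bar


module = ScalerModule(foo=3)
print(module.config, module.scale)
```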

Another important thing to notice here is that we cannot store arrays directly. Every array must be wrapped in a spark.Constant or a spark.Variable. These wrappers are necessary when we JIT-compile the model, to let JAX know which arrays are mutable and which are simple constants. Some basic Python types already work with JIT, but we highly recommend wrapping everything in a Constant or a Variable according to its role in your model.

class MyAwesomeModuleConfig(spark.nn.Config):
    foo: int
    bar: float = 2.0

class MyAwesomeModule(spark.nn.Module):
    config: MyAwesomeModuleConfig

    def __init__(self, config: MyAwesomeModuleConfig = None, **kwargs):
        super().__init__(config=config, **kwargs)
        # self.foo = jnp.array(self.config.foo)   # <--- Would throw an error later, once we JIT-compile!
        self.foo = spark.Constant(jnp.array(self.config.foo))
        self.bar = spark.Variable(jnp.array(self.config.bar))
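The JIT integration of Spark's wrappers cannot be reproduced without Spark, but the difference in role can. The following pure-Python sketch uses hypothetical Constant and Variable classes (not Spark's, and with no JAX involvement) that differ only in whether their value may be reassigned: a constant stays fixed, while a variable holds state that the model updates.

```python
import typing as tp


class _Wrapper:
    """Hypothetical common base: holds one value behind a `value` property."""

    def __init__(self, value: tp.Any):
        self._value = value

    @property
    def value(self) -> tp.Any:
        return self._value


class Constant(_Wrapper):
    """Hypothetical read-only wrapper: `value` has no setter."""


class Variable(_Wrapper):
    """Hypothetical mutable wrapper: `value` can be reassigned."""

    @_Wrapper.value.setter
    def value(self, new_value: tp.Any) -> None:
        self._value = new_value


foo = Constant(1.0)
bar = Variable(2.0)
bar.value = bar.value + foo.value      # state updates are allowed on a Variable
try:
    foo.value = 0.0                    # a Constant refuses reassignment
except AttributeError as err:
    print(f'Constant is read-only: {err}')
```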

spark.nn.Module is an abstract class, so a subclass cannot be instantiated until certain methods are defined. In this case we only need the __init__ and __call__ methods; however, depending on the particular subclass of Module, other methods may be required.
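Assuming Spark builds on Python's standard abc machinery (an assumption; the tutorial only says the class is abstract), the behaviour looks like this: a class derived from abc.ABC cannot be instantiated while any method marked @abc.abstractmethod is left unimplemented. AbstractModule, Incomplete and Complete are hypothetical names.

```python
import abc


class AbstractModule(abc.ABC):
    """Hypothetical stand-in for an abstract base such as spark.nn.Module."""

    @abc.abstractmethod
    def __call__(self, **inputs):
        """Subclasses must implement the forward logic."""


class Incomplete(AbstractModule):
    pass  # __call__ is missing, so this class stays abstract


class Complete(AbstractModule):
    def __call__(self, **inputs):
        return {'out': sum(inputs.values())}


try:
    Incomplete()
except TypeError as err:
    print(f'Cannot instantiate: {err}')

print(Complete()(a=1, b=2))
```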

Apart from the configuration class, __call__ is the other strongly typed element in Spark. The first thing to notice is that __call__ is meant to receive its inputs as keyword arguments rather than positional ones (you can still pass them by position, but this is discouraged, as we will see later in the JIT section). These keyword arguments must always be type-annotated, and their types must inherit from spark.SparkPayload. Payloads are just wrappers around arrays that help the internal machinery know what can connect to what. The array inside any default payload can be accessed via my_payload.value.

Next, __call__ must always specify what it returns. This is done with the return annotation (->) at the end of the __call__ signature. The return type is always a TypedDict that defines the name and type of each variable you intend to return from __call__. Note that, again, all return types must inherit from spark.SparkPayload.

Finally, __call__ returns a dictionary containing every variable specified in the TypedDict, each wrapped in its declared payload type. These few type hints give the internal machinery the guidance it needs to decide what to do in certain circumstances.

And this is all the typing you need to do! No more type hints after this if you do not like them!

class MyAwesomeOutput(tp.TypedDict):
    my_awesome_output: spark.FloatArray

class MyAwesomeModuleConfig(spark.nn.Config):
    foo: int
    bar: float = 2.0

class MyAwesomeModule(spark.nn.Module):
    config: MyAwesomeModuleConfig

    def __init__(self, config: MyAwesomeModuleConfig = None, **kwargs):
        super().__init__(config=config, **kwargs)
        self.foo = spark.Constant(jnp.array(self.config.foo))
        self.bar = spark.Variable(jnp.array(self.config.bar))

    def __call__(self, my_awesome_input: spark.FloatArray) -> MyAwesomeOutput:
        awesome_output = self.foo + self.bar + my_awesome_input
        return {
            'my_awesome_output': spark.FloatArray(awesome_output)
        }
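These annotations are ordinary Python type hints, so they can be read back at runtime. The sketch below uses a hypothetical FloatArray stand-in (not Spark's class) and shows how machinery could discover a module's input and output names and payload types with typing.get_type_hints; whether Spark uses exactly this call internally is an assumption.

```python
import typing as tp


class FloatArray:
    """Hypothetical stand-in for a payload such as spark.FloatArray."""

    def __init__(self, value):
        self.value = value


class MyOutput(tp.TypedDict):
    my_output: FloatArray


class MyModule:
    def __call__(self, my_input: FloatArray) -> MyOutput:
        return {'my_output': FloatArray(my_input.value * 2)}


# Input names and payload types come from the __call__ signature...
hints = tp.get_type_hints(MyModule.__call__)
output_spec = hints.pop('return')                # the TypedDict class
input_spec = hints                               # {'my_input': FloatArray}
# ...and output names and payload types from the TypedDict.
output_fields = tp.get_type_hints(output_spec)   # {'my_output': FloatArray}
print(input_spec, output_fields)
```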

Now, like other modern machine learning frameworks, Spark implements lazy initialization. This lets you defer building parts of your module until you know more about the inputs it receives.

To use this feature, you need to define the build method. This method takes input_specs as its argument: a dictionary containing the specifications of all input variables, such as their shape and payload type.

Note that you can define any other method, as with any other Python object, and it will still play nicely with Spark.

class MyAwesomeOutput(tp.TypedDict):
    my_awesome_output: spark.FloatArray

class MyAwesomeModuleConfig(spark.nn.Config):
    foo: int
    bar: float = 2.0

class MyAwesomeModule(spark.nn.Module):
    config: MyAwesomeModuleConfig

    def __init__(self, config: MyAwesomeModuleConfig = None, **kwargs):
        super().__init__(config=config, **kwargs)
        self.foo = spark.Constant(jnp.array(self.config.foo))

    def build(self, input_specs: dict[str, spark.InputSpec]):
        # bar depends on the input shape, so it is created here rather than in __init__.
        mai_spec = input_specs['my_awesome_input']
        self.bar = spark.Variable(
            mai_spec.payload_type(self.config.bar * jnp.ones(mai_spec.shape))
        )

    def boring_non_typed_function(self, a, b, c):
        return a + b + c

    def __call__(self, my_awesome_input: spark.FloatArray) -> MyAwesomeOutput:
        awesome_output = self.boring_non_typed_function(self.foo, self.bar, my_awesome_input)
        return {
            'my_awesome_output': spark.FloatArray(awesome_output)
        }
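Stripped of Spark and JAX, lazy initialization boils down to deferring the creation of shape-dependent state until the first input arrives. In this tutorial's examples build is never called explicitly, so Spark takes care of it; the pure-Python sketch below instead triggers build by hand on the first call, and assumes a hypothetical InputSpec that holds only a shape. LazyModule is likewise a hypothetical name.

```python
import dataclasses as dc
import typing as tp


@dc.dataclass(frozen=True)
class InputSpec:
    """Hypothetical stand-in for spark.InputSpec: only a shape here."""
    shape: tp.Tuple[int, ...]


class LazyModule:
    """Hypothetical module whose state size depends on its first input."""

    def __init__(self, bar: float = 2.0):
        self.bar_init = bar
        self.bar = None        # unknown until the input shape is known
        self.built = False

    def build(self, input_specs: tp.Dict[str, InputSpec]) -> None:
        size = input_specs['x'].shape[0]
        self.bar = [self.bar_init] * size   # one parameter per input element
        self.built = True

    def __call__(self, x: tp.List[float]) -> tp.Dict[str, tp.List[float]]:
        if not self.built:
            # Lazy initialization: build from the spec of the first input.
            self.build({'x': InputSpec(shape=(len(x),))})
        return {'out': [b + xi for b, xi in zip(self.bar, x)]}


module = LazyModule()
print(module(x=[1.0, 2.0, 3.0]))
```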

This is everything you need to know to build a custom module.

At this point you may be thinking that these Module - Configuration pairs are going to be really annoying to manage, but you could not be more wrong. Informally, the configuration class is more like a specification than a component of the module. There are several ways to initialize a module:

  1. Pass, as keyword arguments, every configuration field that does not have a default.

  2. Initialize the configuration and pass it to the module. Note that it must be provided through the keyword “config”.

  3. A mixture of both: pass a configuration together with keyword arguments. In this case, keyword arguments take precedence over the values defined in the configuration.

# Method 1
awesome = MyAwesomeModule(foo = 1)
my_awesome_input = spark.FloatArray(jnp.array(1))
res = awesome(my_awesome_input=my_awesome_input)
print(f"Method 1\n {res['my_awesome_output'].value}\n")

# Method 2
awesome_config = MyAwesomeModuleConfig(foo = 1)
awesome = MyAwesomeModule(config=awesome_config)
my_awesome_input = spark.FloatArray(jnp.arange(5))
res = awesome(my_awesome_input=my_awesome_input)
print(f"Method 2\n {res['my_awesome_output'].value}\n")

# Method 3
awesome_config = MyAwesomeModuleConfig(foo = 1)
awesome = MyAwesomeModule(config=awesome_config, bar=-1)
my_awesome_input = spark.FloatArray(jnp.arange(4).reshape(2,2))
res = awesome(my_awesome_input=my_awesome_input)
print(f"Method 3\n {res['my_awesome_output'].value}\n")
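The precedence rule of method 3 can be sketched with the standard dataclasses.replace, which copies a dataclass instance while overriding selected fields. The helper resolve_config and the AwesomeConfig dataclass are hypothetical and only mirror the three initialization methods described above; Spark's own merging logic may differ, for instance in how it treats nested configurations.

```python
import dataclasses as dc
import typing as tp


@dc.dataclass
class AwesomeConfig:
    foo: int
    bar: float = 2.0


def resolve_config(config: tp.Optional[AwesomeConfig] = None, **kwargs) -> AwesomeConfig:
    """Hypothetical helper: merge an optional config with keyword overrides."""
    if config is None:
        # Method 1: fields without a default must come from the kwargs.
        return AwesomeConfig(**kwargs)
    # Methods 2 and 3: start from the given config; kwargs take precedence.
    # dc.replace returns a new object and leaves `config` untouched.
    return dc.replace(config, **kwargs)


print(resolve_config(foo=1))                                 # method 1
print(resolve_config(config=AwesomeConfig(foo=1)))           # method 2
print(resolve_config(config=AwesomeConfig(foo=1), bar=-1))   # method 3
```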

We want to highlight that the configuration class is significantly more powerful than we have shown so far.

To make full use of the configuration class, we highly recommend using dc.field (from Python's dataclasses module). Fields let you define significantly more complex configurations and are particularly useful for attaching metadata to variables. (Do not lie to me! You have also met that horrible programmer who worked on your code two months ago and said he did not need to annotate it because he wrote it!)

Internally, one metadata field has important functional behaviour:

  1. validators: a list of validator classes, inheriting from spark.validation.ConfigurationValidator, used to validate the arguments of the configuration class.

Additionally, there are a few other metadata fields that hold some special meaning:

  1. units: spiking neural networks usually work with parameters closely tied to real equations and measurements, so it is helpful to know the expected units. Units are a plain string whose sole purpose is to inform other users of the expected units of an argument.

  2. valid_types: Used for some broadcasting checks.

  3. description: A string description of the purpose of this variable.

import dataclasses as dc

class MyAwesomeModuleConfig(spark.nn.Config):
    foo: int                                             # <-- You can still mix fields with plain annotations.
    bar: float = dc.field(
        default = 2.0,
        metadata = {
            'units': 'nA',                               # <-- nano Awesomes.
            'valid_types': tp.Any,                       # <-- Types accepted by the broadcasting checks.
            'validators': [
                spark.validation.TypeValidator,          # <-- Extra logic to validate your configuration.
                spark.validation.PositiveValidator,
            ],
            'description': 'My awesome bar',             # <-- Text description of this variable.
        }
    )
    baz: list[int] = dc.field(
        default_factory = lambda: [i for i in range(10)]   # <-- Factories are also useful to define variables.
    )

try:
    MyAwesomeModuleConfig(foo = 1, bar = -1)             # <-- Negative value.
except (TypeError, ValueError) as e:
    print(f'Oh no, something went wrong: {e}')

try:
    MyAwesomeModuleConfig(foo = 1, bar = [2.0])          # <-- A list instead of a float.
except (TypeError, ValueError) as e:
    print(f'Oh no, something went wrong: {e}')
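Because dc.field stores its metadata in a read-only mapping that dataclasses.fields exposes at runtime, a validation pass can be built on top of it. The sketch below uses plain dataclasses and two hypothetical validator classes (PositiveValidator and FloatValidator, not the spark.validation ones) invoked from __post_init__; how Spark's ConfigurationValidator classes are actually called is not shown in this tutorial, so the calling convention here is an assumption.

```python
import dataclasses as dc
import typing as tp


class FloatValidator:
    """Hypothetical validator: rejects values that are not int or float."""

    def __call__(self, name: str, value: tp.Any) -> None:
        if not isinstance(value, (int, float)):
            raise TypeError(f'{name} must be a number, got {type(value).__name__}')


class PositiveValidator:
    """Hypothetical validator: rejects non-positive numbers."""

    def __call__(self, name: str, value: tp.Any) -> None:
        if value <= 0:
            raise ValueError(f'{name} must be positive, got {value}')


@dc.dataclass
class ValidatedConfig:
    foo: int
    bar: float = dc.field(
        default=2.0,
        metadata={
            'units': 'nA',
            'description': 'My awesome bar',
            'validators': [FloatValidator, PositiveValidator],  # type check runs first
        },
    )

    def __post_init__(self):
        # Instantiate and run every validator listed in each field's metadata.
        for f in dc.fields(self):
            for validator_cls in f.metadata.get('validators', []):
                validator_cls()(f.name, getattr(self, f.name))


bar_field = {f.name: f for f in dc.fields(ValidatedConfig)}['bar']
print(bar_field.metadata['units'], bar_field.metadata['description'])
```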

Another important feature of configuration classes is that they can be nested. Initializing a configuration whose children have fields without a default value is still simple; we provide several ways to do it:

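  1. Keyword arguments, using a double underscore to reach into a child (e.g. child_bar__foo=2).

  2. Dictionaries passed as keyword arguments (e.g. child_bar={'foo': 5}).

  3. Shared arguments, prefixed with _s_ (e.g. _s_foo=7).

  4. Direct initialization of the child configurations.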

Note that this also works for module initialization!

import dataclasses as dc
import typing as tp

class ChildConfig(spark.nn.Config):
    foo: int
    bar: float = 2.0
    
class ParentConfig(spark.nn.Config):
    foo: int
    child_bar: ChildConfig
    child_baz: ChildConfig = dc.field(
        default_factory = lambda **kwargs: ChildConfig(**{**{'foo': 1}, **kwargs})
    )
    
# Method 1
config = ParentConfig(foo=1, child_bar__foo=2, child_baz__foo=3)
print(f'Method 1\n->Parent:\t{config.foo}\n--->Child Bar:\t{config.child_bar.foo}\n--->Child Baz:\t{config.child_baz.foo}\n')

# Method 2
config = ParentConfig(foo=4, child_bar={'foo':5}, child_baz={'foo':6})
print(f'Method 2\n->Parent:\t{config.foo}\n--->Child Bar:\t{config.child_bar.foo}\n--->Child Baz:\t{config.child_baz.foo}\n')

# Method 3
config = ParentConfig(_s_foo=7)
print(f'Method 3\n->Parent:\t{config.foo}\n--->Child Bar:\t{config.child_bar.foo}\n--->Child Baz:\t{config.child_baz.foo}\n')

# Method 4
bar = ChildConfig(foo=9)
baz = ChildConfig(foo=10)
config = ParentConfig(foo=8, child_bar=bar, child_baz=baz)
print(f'Method 4\n->Parent:\t{config.foo}\n--->Child Bar:\t{config.child_bar.foo}\n--->Child Baz:\t{config.child_baz.foo}\n')
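Method 1 relies on a naming rule: a double underscore separates the parent field that holds a child configuration from the child's own field, so child_bar__foo=2 targets child_bar.foo. The hypothetical helper route_nested_kwargs below (not part of Spark) sketches how such keywords can be turned into the nested-dictionary form of method 2, after which each child configuration could be built from its own dictionary.

```python
import typing as tp


def route_nested_kwargs(**kwargs: tp.Any) -> tp.Dict[str, tp.Any]:
    """Hypothetical helper: turn 'child__field=value' keywords into nested dicts."""
    routed: tp.Dict[str, tp.Any] = {}
    for key, value in kwargs.items():
        # 'child_bar__foo' -> parents ['child_bar'], leaf 'foo';
        # single underscores inside names are left alone.
        *parents, leaf = key.split('__')
        node = routed
        for parent in parents:
            node = node.setdefault(parent, {})   # one dict per nesting level
        node[leaf] = value
    return routed


print(route_nested_kwargs(foo=1, child_bar__foo=2, child_baz__foo=3))
```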

Finally, to let Spark know about your awesome module, there is one last step: registering it. This is achieved with a simple decorator, which allows Spark to discover your module and use it more robustly. This step is optional, but it is required if you want to use your modules within the Graph Editor.

# Final module is registered as 'my_awesome_module'
@spark.register_module
class MyAwesomeModule(spark.nn.Module):
    pass

Alternatively, you can give the module a specific registry name. This is useful when you encounter a name conflict.

# Final module is registered as 'my_even_more_awesome_module'
@spark.register_module('my_even_more_awesome_module')
class MyAwesomeModule(spark.nn.Module):
    pass

Another occasionally useful feature is obtaining the class reference directly from the configuration class. This is particularly handy when you are defining module templates. For example, if you want to test whether one type of synapse works better in a particular scenario, you probably do not want to define two different neuron models just to swap one set of synapses for another; in many scenarios, simply swapping configurations will do the trick.

If you follow the Module - ModuleConfig naming convention, everything is already set up: simply use config.class_ref. If this naming convention is not to your liking, or you register modules under custom names, you need to set up the reference manually. This is easy: just add the registry name of your module to the configuration under __class_ref__.

class MyAwesomeModuleNotFollowingConvention(spark.nn.Config):
    __class_ref__ = 'my_even_more_awesome_module'

This should be everything you need to know to start creating new modules!