spark.nn.components.learning_rules

Classes

LearningRule

Abstract learning rule model.

LearningRuleConfig

Abstract learning rule configuration class.

LearningRuleOutput

Generic learning rule model output spec.

ZenkeRule

Zenke plasticity rule model. This model is an extension of the classic Hebbian rule.

ZenkeRuleConfig

ZenkeRule configuration class.

HebbianRule

Hebbian plasticity rule model.

HebbianRuleConfig

HebbianRule configuration class.

OjaRule

Oja's plasticity rule model.

OjaRuleConfig

OjaRule configuration class.

QuadrupletRule

Quadruplet plasticity rule model.

QuadrupletRuleConfig

QuadrupletRule configuration class.

QuadrupletRuleTensor

Quadruplet plasticity rule model (tensor).

QuadrupletRuleTensorConfig

QuadrupletRuleTensor configuration class.

ThreeFactorHebbianRule

Three-factor Hebbian plasticity rule model.

ThreeFactorHebbianRuleConfig

ThreeFactorHebbianRule configuration class.

Package Contents

class spark.nn.components.learning_rules.LearningRule(config=None, **kwargs)

Bases: spark.nn.components.base.Component, Generic[ConfigT]

Abstract learning rule model.

Parameters:

config (ConfigT | None)

class spark.nn.components.learning_rules.LearningRuleConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.base.ComponentConfig

Abstract learning rule configuration class.

Parameters:

__skip_validation__ (bool)

class spark.nn.components.learning_rules.LearningRuleOutput

Bases: TypedDict

Generic learning rule model output spec.

kernel: spark.core.payloads.FloatArray
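LearningRuleOutput is a TypedDict, so a rule's result is an ordinary dict at runtime. The updated kernel is read with out["kernel"]. The short sketch below illustrates that contract using a hypothetical stand-in class (KernelOutput) and a toy pass-through rule. Neither is part of spark. It also assumes jax.Array in place of spark.core.payloads.FloatArray.

```python
from typing import TypedDict

import jax
import jax.numpy as jnp


class KernelOutput(TypedDict):
    """Hypothetical stand-in with the same single field as LearningRuleOutput."""
    kernel: jax.Array


def passthrough_rule(pre_spikes: jax.Array, post_spikes: jax.Array,
                     kernel: jax.Array) -> KernelOutput:
    """Toy rule that leaves the kernel unchanged; it only shows the output shape."""
    return {"kernel": kernel}


out = passthrough_rule(jnp.zeros(3), jnp.zeros(2), jnp.ones((2, 3)))
# A TypedDict is a plain dict at runtime, so the result is read by key.
new_kernel = out["kernel"]
```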

class spark.nn.components.learning_rules.ZenkeRule(config=None, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRule

Zenke plasticity rule model. This model is an extension of the classic Hebbian rule.

Init:

pre_tau: float | jax.Array
post_tau: float | jax.Array
post_slow_tau: float | jax.Array
target_tau: float | jax.Array
a: float | jax.Array
b: float | jax.Array
c: float | jax.Array
d: float | jax.Array
P: float | jax.Array
eta: float | jax.Array

Input:

pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray

Output:

kernel: FloatArray

Parameters:

config (ZenkeRuleConfig | None)

config: ZenkeRuleConfig
build(input_specs)

Builds the component from the specifications of its input ports.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

reset()

Resets component state.

Return type:

None

__call__(pre_spikes, post_spikes, kernel)

Computes and returns the next kernel update.

Parameters:

pre_spikes (SpikeArray)
post_spikes (SpikeArray)
kernel (FloatArray)

Return type:

spark.nn.components.learning_rules.base.LearningRuleOutput

class spark.nn.components.learning_rules.ZenkeRuleConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRuleConfig

ZenkeRule configuration class.

Parameters:

__skip_validation__ (bool)

pre_tau: float | jax.Array | spark.nn.initializers.base.Initializer
post_tau: float | jax.Array | spark.nn.initializers.base.Initializer
post_slow_tau: float | jax.Array | spark.nn.initializers.base.Initializer
target_tau: float | jax.Array | spark.nn.initializers.base.Initializer
a: float | jax.Array
b: float | jax.Array
c: float | jax.Array
d: float | jax.Array
p: float | jax.Array
eta: float | jax.Array

class spark.nn.components.learning_rules.HebbianRule(config=None, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRule

Hebbian plasticity rule model.

Init:

pre_tau: float | jax.Array
post_tau: float | jax.Array
gamma: float | jax.Array

Input:

pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray

Output:

kernel: FloatArray
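The Init parameters suggest a trace-based formulation. Spikes are low-pass filtered with time constants pre_tau and post_tau, and the kernel grows with the product of the two traces. The JAX sketch below shows one common version of that idea. It is not spark's implementation. The following are all assumptions:

- the function names;
- the (n_post, n_pre) kernel layout;
- the unit time step dt;
- the use of a learning rate called eta (the name HebbianRuleConfig uses).

```python
import jax
import jax.numpy as jnp


def decay_trace(trace: jax.Array, spikes: jax.Array, tau: float, dt: float = 1.0) -> jax.Array:
    """Exponentially decay a trace by one time step, then add the new spikes."""
    return trace * jnp.exp(-dt / tau) + spikes


def hebbian_step(state, pre_spikes, post_spikes, kernel,
                 pre_tau=20.0, post_tau=20.0, eta=1e-3, dt=1.0):
    """One step of a hypothetical trace-based Hebbian rule.

    Assumes kernel[i, j] is the weight from presynaptic neuron j to
    postsynaptic neuron i, i.e. shape (n_post, n_pre).
    """
    pre_trace = decay_trace(state["pre_trace"], pre_spikes, pre_tau, dt)
    post_trace = decay_trace(state["post_trace"], post_spikes, post_tau, dt)
    # Hebbian term: correlation of filtered post- and presynaptic activity.
    new_kernel = kernel + eta * jnp.outer(post_trace, pre_trace)
    new_state = {"pre_trace": pre_trace, "post_trace": post_trace}
    # Return the kernel under the same key that LearningRuleOutput uses.
    return new_state, {"kernel": new_kernel}


# Usage: 3 presynaptic and 2 postsynaptic neurons, starting from zero traces.
state = {"pre_trace": jnp.zeros(3), "post_trace": jnp.zeros(2)}
kernel = jnp.zeros((2, 3))
pre = jnp.array([1.0, 0.0, 1.0])
post = jnp.array([0.0, 1.0])
state, out = hebbian_step(state, pre, post, kernel)
```

Only synapses whose presynaptic and postsynaptic neurons were both recently active change (row 1, columns 0 and 2). A plain Hebbian term like this grows without bound, which is the usual motivation for normalized variants such as Oja's rule.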

Parameters:

config (HebbianRuleConfig | None)

config: HebbianRuleConfig
build(input_specs)

Builds the component from the specifications of its input ports.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

reset()

Resets component state.

Return type:

None

__call__(pre_spikes, post_spikes, kernel)

Computes and returns the next kernel update.

Parameters:

pre_spikes (SpikeArray)
post_spikes (SpikeArray)
kernel (FloatArray)

Return type:

spark.nn.components.learning_rules.base.LearningRuleOutput

class spark.nn.components.learning_rules.HebbianRuleConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRuleConfig

HebbianRule configuration class.

Parameters:

__skip_validation__ (bool)

pre_tau: float | jax.Array | spark.nn.initializers.base.Initializer
post_tau: float | jax.Array | spark.nn.initializers.base.Initializer
eta: float

class spark.nn.components.learning_rules.OjaRule(config=None, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRule

Oja’s plasticity rule model.

Init:

pre_tau: float | jax.Array
post_tau: float | jax.Array
gamma: float | jax.Array

Input:

pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray

Output:

kernel: FloatArray
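Oja's rule adds a decay term to the Hebbian product. This keeps each postsynaptic neuron's incoming weight vector bounded. For input x and output y the classic update is Δw = η·y·(x − y·w). The JAX sketch below applies that update row by row to a kernel of assumed shape (n_post, n_pre). OjaRuleConfig lists only post_tau and eta, so the sketch filters only the postsynaptic spikes and uses the raw presynaptic spikes as x. How spark actually forms x and y is an assumption, as are the function name and the time step dt.

```python
import jax
import jax.numpy as jnp


def oja_step(post_trace: jax.Array, pre_spikes: jax.Array, post_spikes: jax.Array,
             kernel: jax.Array, post_tau: float = 20.0, eta: float = 1e-2,
             dt: float = 1.0):
    """One step of a hypothetical spike-driven Oja update.

    kernel is assumed to have shape (n_post, n_pre).
    """
    # Postsynaptic activity y: exponentially filtered postsynaptic spikes.
    y = post_trace * jnp.exp(-dt / post_tau) + post_spikes
    # Presynaptic activity x: the raw spikes of this step (an assumption).
    x = pre_spikes
    # Oja's rule per row: dW[i, j] = eta * y[i] * (x[j] - y[i] * W[i, j])
    d_kernel = eta * (jnp.outer(y, x) - (y ** 2)[:, None] * kernel)
    return y, {"kernel": kernel + d_kernel}


# Usage: the postsynaptic neuron fires with no presynaptic input,
# so only the decay term acts and the existing weight shrinks.
post_trace, out = oja_step(
    post_trace=jnp.zeros(1),
    pre_spikes=jnp.array([0.0, 0.0]),
    post_spikes=jnp.array([1.0]),
    kernel=jnp.array([[2.0, 0.0]]),
)
```

Here the first weight drops from 2.0 to 1.98 (2.0 − 0.01·1²·2.0). A plain Hebbian update would leave it unchanged in the absence of presynaptic input.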

Parameters:

config (OjaRuleConfig | None)

config: OjaRuleConfig
build(input_specs)

Builds the component from the specifications of its input ports.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

reset()

Resets component state.

Return type:

None

__call__(pre_spikes, post_spikes, kernel)

Computes and returns the next kernel update.

Parameters:

pre_spikes (SpikeArray)
post_spikes (SpikeArray)
kernel (FloatArray)

Return type:

spark.nn.components.learning_rules.base.LearningRuleOutput

class spark.nn.components.learning_rules.OjaRuleConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRuleConfig

OjaRule configuration class.

Parameters:

__skip_validation__ (bool)

post_tau: float | jax.Array | spark.nn.initializers.base.Initializer
eta: float

class spark.nn.components.learning_rules.QuadrupletRule(config=None, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRule

Quadruplet plasticity rule model.

Init:

pre_tau: float | jax.Array
post_tau: float | jax.Array
q_alpha: float | jax.Array
q_beta: float | jax.Array
q_gamma: float | jax.Array
q_delta: float | jax.Array
eta: float | jax.Array

Input:

modulation: FloatArray
pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray

Output:

kernel: FloatArray

Parameters:

config (QuadrupletRuleConfig | None)

config: QuadrupletRuleConfig
build(input_specs)

Builds the component from the specifications of its input ports.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

reset()

Resets component state.

Return type:

None

__call__(modulation, pre_spikes, post_spikes, kernel)

Computes and returns the next kernel update.

Parameters:

modulation (FloatArray)
pre_spikes (SpikeArray)
post_spikes (SpikeArray)
kernel (FloatArray)

Return type:

spark.nn.components.learning_rules.base.LearningRuleOutput

class spark.nn.components.learning_rules.QuadrupletRuleConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRuleConfig

QuadrupletRule configuration class.

Parameters:

__skip_validation__ (bool)

pre_tau: float | jax.Array
post_tau: float | jax.Array
q_alpha: float | jax.Array
q_beta: float | jax.Array
q_gamma: float | jax.Array
q_delta: float | jax.Array
gamma: float | jax.Array

class spark.nn.components.learning_rules.QuadrupletRuleTensor(config=None, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRule

Quadruplet plasticity rule model (tensor).

Init:

pre_tau: float | jax.Array
post_tau: float | jax.Array
eta: float | jax.Array

Input:

modulation: FloatArray
pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray

Output:

kernel: FloatArray

Parameters:

config (QuadrupletRuleTensorConfig | None)

config: QuadrupletRuleTensorConfig
build(input_specs)

Builds the component from the specifications of its input ports.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

reset()

Resets component state.

Return type:

None

__call__(modulation, pre_spikes, post_spikes, kernel)

Computes and returns the next kernel update.

Parameters:

modulation (FloatArray)
pre_spikes (SpikeArray)
post_spikes (SpikeArray)
kernel (FloatArray)

Return type:

spark.nn.components.learning_rules.base.LearningRuleOutput

class spark.nn.components.learning_rules.QuadrupletRuleTensorConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRuleConfig

QuadrupletRuleTensor configuration class.

Parameters:

__skip_validation__ (bool)

pre_tau: tuple[float, float, float, float]
post_tau: tuple[float, float, float, float]
q_alpha: tuple[float, float, float, float]
q_beta: tuple[float, float, float, float]
q_gamma: tuple[float, float, float, float]
q_delta: tuple[float, float, float, float]
max_clip: tuple[float, float, float, float]
eta: float

class spark.nn.components.learning_rules.ThreeFactorHebbianRule(config=None, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRule

Three-factor Hebbian plasticity rule model.

Init:

pre_tau: float | jax.Array
post_tau: float | jax.Array
gamma: float | jax.Array

Input:

pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray

Output:

kernel: FloatArray
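In a three-factor rule, the coincidence of pre- and postsynaptic activity only proposes a change. A third signal, here the reward argument of __call__, decides whether that change is applied and with which sign. The JAX sketch below gates a trace-based Hebbian product with that signal. The following are assumptions:

- a scalar reward;
- gamma as the learning rate;
- the (n_post, n_pre) kernel layout;
- the function names.

Many published three-factor rules also keep a separately decaying eligibility trace, which this minimal version omits.

```python
import jax
import jax.numpy as jnp


def decay_trace(trace: jax.Array, spikes: jax.Array, tau: float, dt: float = 1.0) -> jax.Array:
    """Exponentially decay a trace by one time step, then add the new spikes."""
    return trace * jnp.exp(-dt / tau) + spikes


def three_factor_step(state, reward, pre_spikes, post_spikes, kernel,
                      pre_tau=20.0, post_tau=20.0, gamma=1e-3, dt=1.0):
    """One step of a hypothetical reward-gated Hebbian rule (kernel: (n_post, n_pre))."""
    pre_trace = decay_trace(state["pre_trace"], pre_spikes, pre_tau, dt)
    post_trace = decay_trace(state["post_trace"], post_spikes, post_tau, dt)
    # Factors one and two: the Hebbian coincidence term.
    hebbian = jnp.outer(post_trace, pre_trace)
    # Factor three: the reward scales (or flips the sign of) the update.
    new_kernel = kernel + gamma * reward * hebbian
    new_state = {"pre_trace": pre_trace, "post_trace": post_trace}
    return new_state, {"kernel": new_kernel}


# Usage: identical spikes, different rewards.
zeros = {"pre_trace": jnp.zeros(3), "post_trace": jnp.zeros(2)}
kernel = jnp.zeros((2, 3))
pre = jnp.array([1.0, 0.0, 1.0])
post = jnp.array([0.0, 1.0])
_, no_reward = three_factor_step(zeros, 0.0, pre, post, kernel)
_, punished = three_factor_step(zeros, -1.0, pre, post, kernel)
```

With zero reward the kernel is left untouched even though the spikes coincide. A negative reward turns the same coincidence into depression.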

Parameters:

config (ThreeFactorHebbianRuleConfig | None)

config: ThreeFactorHebbianRuleConfig
build(input_specs)

Builds the component from the specifications of its input ports.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

reset()

Resets component state.

Return type:

None

__call__(reward, pre_spikes, post_spikes, kernel)

Computes and returns the next kernel update.

Parameters:

reward
pre_spikes (SpikeArray)
post_spikes (SpikeArray)
kernel (FloatArray)

Return type:

spark.nn.components.learning_rules.base.LearningRuleOutput

class spark.nn.components.learning_rules.ThreeFactorHebbianRuleConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRuleConfig

ThreeFactorHebbianRule configuration class.

Parameters:

__skip_validation__ (bool)

pre_tau: float | jax.Array | spark.nn.initializers.base.Initializer
post_tau: float | jax.Array | spark.nn.initializers.base.Initializer
gamma: float | jax.Array