spark.nn.components.learning_rules.hebbian_rule

Classes

HebbianRuleConfig

HebbianRule configuration class.

HebbianRule

Hebbian plasticity rule model.

OjaRuleConfig

OjaRule configuration class.

OjaRule

Oja's plasticity rule model.

Module Contents

class spark.nn.components.learning_rules.hebbian_rule.HebbianRuleConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRuleConfig

HebbianRule configuration class.

Parameters:

__skip_validation__ (bool)

pre_tau: float | jax.Array | spark.nn.initializers.base.Initializer
post_tau: float | jax.Array | spark.nn.initializers.base.Initializer
eta: float
class spark.nn.components.learning_rules.hebbian_rule.HebbianRule(config=None, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRule

Hebbian plasticity rule model.

Init:

pre_tau: float | jax.Array
post_tau: float | jax.Array
gamma: float | jax.Array

Input:

pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray

Output:

kernel: FloatArray

Parameters:

config (HebbianRuleConfig | None)
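The configuration fields pre_tau and post_tau are named like time constants. A common way such constants enter a spiking learning rule is through exponentially decaying spike traces. The sketch below assumes that interpretation, plus a fixed simulation step dt that is not part of the documented API. It is plain JAX for illustration, not spark's implementation.

```python
import jax
import jax.numpy as jnp


def update_trace(trace, spikes, tau, dt=1.0):
    """One step of an exponentially decaying spike trace (assumed form)."""
    # Decay toward zero with time constant tau, then add the new spikes.
    return trace * jnp.exp(-dt / tau) + spikes


def run_traces(spike_train, tau, dt=1.0):
    """Scan update_trace over a (T, n) spike train and return (T, n) traces."""
    def step(trace, spikes):
        new_trace = update_trace(trace, spikes, tau, dt)
        return new_trace, new_trace

    init = jnp.zeros(spike_train.shape[1:], dtype=jnp.float32)
    _, traces = jax.lax.scan(step, init, spike_train)
    return traces


# One neuron spiking at steps 0 and 3, with tau = 2 steps.
spikes = jnp.array([[1.0], [0.0], [0.0], [1.0]])
traces = run_traces(spikes, tau=2.0)
print(traces[:, 0])
```

A larger tau makes the trace remember older spikes for longer, which widens the window in which pre- and postsynaptic activity count as coincident.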

config: HebbianRuleConfig
build(input_specs)

Builds the component from the specifications of its input ports.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

reset()

Resets component state.

Return type:

None

__call__(pre_spikes, post_spikes, kernel)

Computes and returns the next kernel update.

Parameters:

pre_spikes (SpikeArray)
post_spikes (SpikeArray)
kernel (FloatArray)

Return type:

spark.nn.components.learning_rules.base.LearningRuleOutput
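A plain Hebbian rule strengthens a synapse in proportion to the co-activity of its pre- and postsynaptic sides. The sketch below shows one common trace-based form, Δw_ij = eta · x_i · y_j, where x and y are pre- and postsynaptic traces. The (n_pre, n_post) kernel layout, the use of traces instead of raw spikes, and returning the updated kernel rather than the increment are all assumptions, not a description of HebbianRule's internals.

```python
import jax.numpy as jnp


def hebbian_update(kernel, pre_trace, post_trace, eta):
    """One plain Hebbian step (assumed form, not spark's code).

    kernel:     (n_pre, n_post) weight matrix (layout is an assumption)
    pre_trace:  (n_pre,) presynaptic activity trace
    post_trace: (n_post,) postsynaptic activity trace
    """
    # Correlation term: w_ij grows when pre i and post j are co-active.
    dw = eta * jnp.outer(pre_trace, post_trace)
    return kernel + dw


kernel = jnp.zeros((2, 3))
pre = jnp.array([1.0, 0.0])
post = jnp.array([0.5, 0.0, 2.0])
new_kernel = hebbian_update(kernel, pre, post, eta=0.1)
print(new_kernel)
```

Because this term only ever adds positive correlations, weights grow without bound under sustained co-activity. Bounded variants such as Oja's rule address this.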

class spark.nn.components.learning_rules.hebbian_rule.OjaRuleConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRuleConfig

OjaRule configuration class.

Parameters:

__skip_validation__ (bool)

post_tau: float | jax.Array | spark.nn.initializers.base.Initializer
eta: float
class spark.nn.components.learning_rules.hebbian_rule.OjaRule(config=None, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRule

Oja's plasticity rule model.

Init:

pre_tau: float | jax.Array
post_tau: float | jax.Array
gamma: float | jax.Array

Input:

pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray

Output:

kernel: FloatArray

Parameters:

config (OjaRuleConfig | None)

config: OjaRuleConfig
build(input_specs)

Builds the component from the specifications of its input ports.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

reset()

Resets component state.

Return type:

None

__call__(pre_spikes, post_spikes, kernel)

Computes and returns the next kernel update.

Parameters:

pre_spikes (SpikeArray)
post_spikes (SpikeArray)
kernel (FloatArray)

Return type:

spark.nn.components.learning_rules.base.LearningRuleOutput
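Oja's rule adds a decay term to the Hebbian product so that weights stay bounded. For a linear unit with output y = w·x, the textbook form is Δw = η y (x − y w). The sketch below applies that form independently to each postsynaptic column of an assumed (n_pre, n_post) kernel, with plain activity vectors standing in for spikes or traces. How OjaRule turns spikes and its post_tau trace into these quantities is not documented here, so treat this as an illustration only.

```python
import jax
import jax.numpy as jnp


def oja_update(kernel, pre, post, eta):
    """One step of Oja's rule applied to each postsynaptic column.

    kernel: (n_pre, n_post) weights (layout is an assumption)
    pre:    (n_pre,) presynaptic activity x
    post:   (n_post,) postsynaptic activity y
    Implements dw_ij = eta * y_j * (x_i - y_j * w_ij).
    """
    hebb = jnp.outer(pre, post)              # x_i * y_j
    decay = (post ** 2)[None, :] * kernel    # y_j**2 * w_ij
    return kernel + eta * (hebb - decay)


# Demo: one linear unit y = x @ kernel trained on 2-D inputs whose
# standard deviation is 2.0 along the first axis and 0.5 along the second.
key = jax.random.PRNGKey(0)
xs = jax.random.normal(key, (5000, 2)) * jnp.array([2.0, 0.5])


def step(kernel, x):
    y = x @ kernel                           # shape (n_post,) = (1,)
    return oja_update(kernel, x, y, eta=0.005), None


kernel0 = jnp.array([[0.3], [0.4]])
kernel, _ = jax.lax.scan(step, kernel0, xs)
w = kernel[:, 0]
print(jnp.linalg.norm(w), w)
```

For a single output unit, Oja's rule drives w toward the unit-norm leading eigenvector of the input covariance, so here w should end close to ±(1, 0). Applied column by column as above, every postsynaptic unit converges toward the same component; extracting several distinct components requires a variant such as Sanger's rule.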