spark.nn.components.learning_rules.zenke_rule

Classes

ZenkeRuleConfig

ZenkeRule configuration class.

ZenkeRule

Zenke plasticity rule model. This model is an extension of the classic Hebbian rule.

Module Contents

class spark.nn.components.learning_rules.zenke_rule.ZenkeRuleConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRuleConfig

ZenkeRule configuration class.

Parameters:

__skip_validation__ (bool)

pre_tau: float | jax.Array | spark.nn.initializers.base.Initializer
post_tau: float | jax.Array | spark.nn.initializers.base.Initializer
post_slow_tau: float | jax.Array | spark.nn.initializers.base.Initializer
target_tau: float | jax.Array | spark.nn.initializers.base.Initializer
a: float | jax.Array
b: float | jax.Array
c: float | jax.Array
d: float | jax.Array
p: float | jax.Array
eta: float | jax.Array
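The four `*_tau` fields are time constants; each may be a scalar, an array, or an Initializer. In discrete-time simulations a time constant is commonly applied as a per-step decay factor exp(-dt / tau) for an exponential trace. The sketch below assumes that discretization, a step size `dt` in the same units as the taus, and purely illustrative values. `TauSketch` and `decay_factors` are hypothetical names, not part of spark.

```python
import math
from dataclasses import dataclass, fields


@dataclass
class TauSketch:
    """Hypothetical stand-in for the time constants of ZenkeRuleConfig.

    Values are illustrative only (same units as dt, e.g. ms); they are not spark defaults.
    """
    pre_tau: float = 20.0
    post_tau: float = 20.0
    post_slow_tau: float = 100.0
    target_tau: float = 1000.0


def decay_factors(cfg: TauSketch, dt: float = 1.0) -> dict[str, float]:
    """Map each time constant to the per-step decay exp(-dt / tau) of an exponential trace."""
    return {f.name: math.exp(-dt / getattr(cfg, f.name)) for f in fields(cfg)}


# Longer time constants decay more slowly, so their factors sit closer to 1.
factors = decay_factors(TauSketch(), dt=1.0)
```

Exponential-Euler decay is exact for a trace that only decays between spikes, which is why it is a common choice over a first-order `1 - dt / tau` approximation.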
class spark.nn.components.learning_rules.zenke_rule.ZenkeRule(config=None, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRule

Zenke plasticity rule model. This model is an extension of the classic Hebbian rule.

Init:

pre_tau: float | jax.Array
post_tau: float | jax.Array
post_slow_tau: float | jax.Array
target_tau: float | jax.Array
a: float | jax.Array
b: float | jax.Array
c: float | jax.Array
d: float | jax.Array
p: float | jax.Array
eta: float | jax.Array

Input:

pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray

Output:

kernel: FloatArray

Parameters:

config (ZenkeRuleConfig | None)
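ZenkeRule is described as an extension of the classic Hebbian rule, and it has the same kind of contract: presynaptic spikes, postsynaptic spikes and the current kernel go in, a kernel comes out. For reference, a minimal sketch of the classic rule in that shape follows. The `(n_post, n_pre)` kernel layout and the name `hebbian_update` are assumptions for illustration; spark's kernel orientation is not stated on this page. Plain NumPy is used for portability, and each call has a direct `jax.numpy` equivalent.

```python
import numpy as np


def hebbian_update(pre_spikes, post_spikes, kernel, eta=0.01):
    """Classic Hebbian rule applied to spikes: w[i, j] += eta * post[i] * pre[j].

    Assumes kernel has shape (n_post, n_pre) and spikes are 0/1 vectors.
    Returns the updated kernel without modifying the input array.
    """
    pre = np.asarray(pre_spikes, dtype=float)
    post = np.asarray(post_spikes, dtype=float)
    return np.asarray(kernel, dtype=float) + eta * np.outer(post, pre)


# Postsynaptic neuron 0 fires in the same step as presynaptic neuron 0 only.
new_kernel = hebbian_update([1, 0], [1], np.zeros((1, 2)), eta=0.5)
```

This baseline only changes a weight when pre- and postsynaptic spikes coincide in the same step. Trace-based extensions instead let spikes at different times interact through decaying traces, which is presumably what the `*_tau` time constants of ZenkeRuleConfig govern.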

config: ZenkeRuleConfig
build(input_specs)

Builds the component from the specifications of its input ports.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])
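`build` receives the input port specifications keyed by port name, which is the natural point for a rule to size its internal buffers. The sketch below is hypothetical: it uses plain shape tuples in place of `spark.core.specs.PortSpecs` (whose fields are not documented here), assumes the port names match the Input list of ZenkeRule, assumes 1-D spike vectors and an `(n_post, n_pre)` kernel, and invents the name `build_trace_state`.

```python
import numpy as np


def build_trace_state(input_shapes: dict[str, tuple[int, ...]]) -> dict[str, np.ndarray]:
    """Hypothetical build step: size zeroed trace buffers from the input port shapes.

    input_shapes stands in for the PortSpecs dict; only shapes are used here.
    """
    (n_pre,) = input_shapes["pre_spikes"]
    (n_post,) = input_shapes["post_spikes"]
    kernel_shape = tuple(input_shapes["kernel"])
    # Catch mismatched ports at build time rather than on the first call.
    if kernel_shape != (n_post, n_pre):
        raise ValueError(f"kernel shape {kernel_shape} does not match (n_post, n_pre) = {(n_post, n_pre)}")
    return {
        "pre": np.zeros(n_pre),         # presynaptic trace
        "post": np.zeros(n_post),       # fast postsynaptic trace
        "post_slow": np.zeros(n_post),  # slow postsynaptic trace
    }


state = build_trace_state({"pre_spikes": (3,), "post_spikes": (2,), "kernel": (2, 3)})
```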

reset()

Resets component state.

Return type:

None
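For a trace-based rule, resetting state usually means clearing the spike traces so that activity from one trial cannot drive weight changes in the next. The minimal sketch below assumes the state is a dict of trace arrays (a hypothetical layout) and returns fresh zeroed arrays instead of mutating in place, which also suits JAX's immutable arrays. It leaves the kernel alone, since in this interface `__call__` receives the kernel as an input.

```python
import numpy as np


def reset_traces(traces: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Return zeroed copies of every trace, preserving shapes and dtypes."""
    return {name: np.zeros_like(value) for name, value in traces.items()}


traces = {"pre": np.array([0.5, 1.2]), "post": np.array([0.3])}
cleared = reset_traces(traces)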

__call__(pre_spikes, post_spikes, kernel)

Computes and returns the next kernel update.

Parameters:

pre_spikes (SpikeArray)
post_spikes (SpikeArray)
kernel (FloatArray)

Return type:

spark.nn.components.learning_rules.base.LearningRuleOutput
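To make the "pre spikes, post spikes, kernel in; kernel out" step concrete, here is a minimal sketch of one step of a triplet-style trace rule: potentiation at a postsynaptic spike is proportional to the presynaptic trace times a slow postsynaptic trace, and depression at a presynaptic spike is proportional to the fast postsynaptic trace. This is one common form of the trace terms such Hebbian extensions use, not code from spark. Only `a`, `b` and `eta` appear; how spark uses `c`, `d`, `p` and `target_tau` is not documented on this page, so those terms are omitted. The name `triplet_step`, the dict-of-traces state, the `(n_post, n_pre)` layout, and returning the new kernel (rather than only the increment) are all assumptions.

```python
import numpy as np


def triplet_step(pre_spikes, post_spikes, kernel, traces, *, a, b, eta,
                 pre_decay, post_decay, post_slow_decay):
    """One discrete step of a hypothetical triplet-style Hebbian rule.

    Conventions assumed for this sketch (not taken from spark):
      * kernel has shape (n_post, n_pre); spikes are 0/1 vectors.
      * traces holds 'pre', 'post' and 'post_slow' arrays from the previous step.
      * Weight changes use only spikes from earlier steps, so a pre and a post
        spike in the same step do not pair with each other.
    Returns (new_kernel, new_traces).
    """
    pre = np.asarray(pre_spikes, dtype=float)
    post = np.asarray(post_spikes, dtype=float)

    # Exponential decay of each trace over one step.
    z_pre = traces["pre"] * pre_decay
    z_post = traces["post"] * post_decay
    z_slow = traces["post_slow"] * post_slow_decay

    # LTP at a post spike i: a * slow post trace_i * pre trace_j (triplet term).
    ltp = a * np.outer(post * z_slow, z_pre)
    # LTD at a pre spike j: b * fast post trace_i (pair term).
    ltd = b * np.outer(z_post, pre)
    new_kernel = np.asarray(kernel, dtype=float) + eta * (ltp - ltd)

    # Add this step's spikes to the traces after the weight update.
    new_traces = {"pre": z_pre + pre, "post": z_post + post, "post_slow": z_slow + post}
    return new_kernel, new_traces


params = dict(a=1.0, b=1.0, eta=1.0, pre_decay=0.5, post_decay=0.5, post_slow_decay=0.9)
kernel = np.zeros((1, 1))
traces = {"pre": np.zeros(1), "post": np.zeros(1), "post_slow": np.zeros(1)}
# A pre spike, then two post spikes: only the second post spike sees a nonzero
# slow trace, so only it potentiates the synapse.
for pre, post in [([1], [0]), ([0], [1]), ([0], [1])]:
    kernel, traces = triplet_step(pre, post, kernel, traces, **params)
```

The decay factors would typically be derived from the time constants as exp(-dt / tau). The slow-trace factor is what makes potentiation depend on recent postsynaptic firing, the usual motivation for a triplet term over a plain pair rule.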