spark.nn.components.learning_rules.three_factor_rule

Classes

ThreeFactorHebbianRuleConfig

ThreeFactorHebbianRule configuration class.

ThreeFactorHebbianRule

Three-factor Hebbian plasticity rule model.

Module Contents

class spark.nn.components.learning_rules.three_factor_rule.ThreeFactorHebbianRuleConfig(__skip_validation__=False, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRuleConfig

ThreeFactorHebbianRule configuration class.

Parameters:

__skip_validation__ (bool)

pre_tau: float | jax.Array | spark.nn.initializers.base.Initializer
post_tau: float | jax.Array | spark.nn.initializers.base.Initializer
gamma: float | jax.Array
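The configuration exposes two time constants and one scale factor, but this page does not state the update equations. A common three-factor formulation, assumed here only for illustration and not taken from Spark's source, keeps exponentially decaying spike traces for the pre- and postsynaptic neurons and gates a Hebbian product of those traces with a reward signal r(t). The step size Δt, the traces x and the index convention (i postsynaptic, j presynaptic) are assumptions.

```latex
\begin{aligned}
x^{\text{pre}}_j(t+\Delta t)  &= x^{\text{pre}}_j(t)\, e^{-\Delta t/\tau_{\text{pre}}}  + s^{\text{pre}}_j(t) \\
x^{\text{post}}_i(t+\Delta t) &= x^{\text{post}}_i(t)\, e^{-\Delta t/\tau_{\text{post}}} + s^{\text{post}}_i(t) \\
W_{ij} &\leftarrow W_{ij} + \gamma\, r(t)\, x^{\text{post}}_i\, x^{\text{pre}}_j
\end{aligned}
```

Under this reading, τ_pre, τ_post and γ play the roles of pre_tau, post_tau and gamma, s denotes the binary spike inputs, and W is the kernel.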
class spark.nn.components.learning_rules.three_factor_rule.ThreeFactorHebbianRule(config=None, **kwargs)

Bases: spark.nn.components.learning_rules.base.LearningRule

Three-factor Hebbian plasticity rule model.

Init:

pre_tau: float | jax.Array
post_tau: float | jax.Array
gamma: float | jax.Array

Input:

pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray

Output:

kernel: FloatArray
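pre_tau and post_tau are time constants. Assuming they set the exponential decay of per-neuron spike traces (a common design for Hebbian rules, not confirmed by this page), one trace step can be sketched in plain Python. The function name update_trace and the step size dt are hypothetical and not part of Spark's API.

```python
import math


def update_trace(trace, spikes, tau, dt=1.0):
    """Decay each trace by exp(-dt / tau), then add 1.0 for every neuron that spiked.

    trace:  list[float], one entry per neuron
    spikes: list of 0/1 (or bool), 1 where the neuron spiked this step
    tau:    time constant, standing in for pre_tau or post_tau
    dt:     simulation step size (hypothetical; not a documented parameter)
    """
    decay = math.exp(-dt / tau)
    return [x * decay + float(s) for x, s in zip(trace, spikes)]


# Usage: a single presynaptic neuron spikes once, then stays silent.
pre_tau = 20.0
trace = [0.0]
history = []
for spike in [1, 0, 0, 0]:
    trace = update_trace(trace, [spike], pre_tau)
    history.append(trace[0])
```

A larger tau makes the trace decay more slowly, so a spike keeps influencing the weight update for longer.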

Parameters:

config (ThreeFactorHebbianRuleConfig | None)

config: ThreeFactorHebbianRuleConfig
build(input_specs)

Builds the component from the specifications of its input ports.

Parameters:

input_specs (dict[str, spark.core.specs.PortSpecs])

reset()

Resets component state.

Return type:

None

__call__(reward, pre_spikes, post_spikes, kernel)

Computes and returns the next kernel update.

Parameters:

reward
pre_spikes (SpikeArray)
post_spikes (SpikeArray)
kernel (FloatArray)

Return type:

spark.nn.components.learning_rules.base.LearningRuleOutput
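__call__ takes a reward alongside pre_spikes, post_spikes and kernel. A minimal sketch of a reward-gated Hebbian step shows the role of this third factor. It assumes the rule adds gamma times the reward times the outer product of post- and presynaptic traces to the kernel; the helper name three_factor_kernel_update, the use of traces as inputs and the kernel[post][pre] layout are all assumptions, and the sketch returns a plain nested list rather than a LearningRuleOutput.

```python
def three_factor_kernel_update(reward, pre_trace, post_trace, kernel, gamma):
    """Return a new kernel equal to kernel + gamma * reward * outer(post_trace, pre_trace).

    Hypothetical helper. kernel is assumed to be indexed as kernel[i][j],
    with i the postsynaptic and j the presynaptic neuron. The input kernel
    is not modified; a new nested list is built instead.
    """
    return [
        [w + gamma * reward * post_i * pre_j for w, pre_j in zip(row, pre_trace)]
        for row, post_i in zip(kernel, post_trace)
    ]


kernel = [[0.5, 0.5], [0.5, 0.5]]
pre_trace = [1.0, 0.0]   # presynaptic neuron 0 was recently active
post_trace = [0.0, 2.0]  # postsynaptic neuron 1 was recently active

no_reward = three_factor_kernel_update(0.0, pre_trace, post_trace, kernel, gamma=0.1)
rewarded = three_factor_kernel_update(1.0, pre_trace, post_trace, kernel, gamma=0.1)
punished = three_factor_kernel_update(-1.0, pre_trace, post_trace, kernel, gamma=0.1)
```

With zero reward the kernel is unchanged even though both neurons were active; a positive reward strengthens only the synapse whose pre- and postsynaptic traces are both nonzero (kernel[1][0]), and a negative reward weakens it.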