spark.nn.components.learning_rules.three_factor_rule
Classes
ThreeFactorHebbianRuleConfig | ThreeFactorHebbianRule configuration class.
ThreeFactorHebbianRule | Three-factor Hebbian plasticity rule model.
Module Contents
- class spark.nn.components.learning_rules.three_factor_rule.ThreeFactorHebbianRuleConfig(__skip_validation__=False, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRuleConfig
ThreeFactorHebbianRule configuration class.
- Parameters:
__skip_validation__ (bool)
- post_tau: float | jax.Array | spark.nn.initializers.base.Initializer
- class spark.nn.components.learning_rules.three_factor_rule.ThreeFactorHebbianRule(config=None, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRule
Three-factor Hebbian plasticity rule model.
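This page does not state the update equation. As a hedged illustration only, one common reward-modulated, STDP-style form that uses the three init arguments (`pre_tau`, `post_tau`, `gamma`) and the `reward` input of `__call__` is written out here. The trace definitions, the sign convention, the time step $\Delta t$, and the index layout $W_{ij}$ (presynaptic $i$, postsynaptic $j$) are assumptions, not documented behavior of `ThreeFactorHebbianRule`.

$$
\begin{aligned}
x^{\mathrm{pre}}_i(t) &= e^{-\Delta t/\tau_{\mathrm{pre}}}\, x^{\mathrm{pre}}_i(t-\Delta t) + s^{\mathrm{pre}}_i(t) \\
x^{\mathrm{post}}_j(t) &= e^{-\Delta t/\tau_{\mathrm{post}}}\, x^{\mathrm{post}}_j(t-\Delta t) + s^{\mathrm{post}}_j(t) \\
\Delta W_{ij}(t) &= \gamma\, r(t)\left[x^{\mathrm{pre}}_i(t)\, s^{\mathrm{post}}_j(t) - s^{\mathrm{pre}}_i(t)\, x^{\mathrm{post}}_j(t)\right]
\end{aligned}
$$

Here $s$ are 0/1 spike indicators, $x$ are exponentially decaying spike traces, and $r(t)$ is the third factor: the pre/post (Hebbian) term changes the weights only when the reward is nonzero, and a negative reward reverses its sign.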
- Init:
pre_tau: float | jax.Array
post_tau: float | jax.Array
gamma: float | jax.Array
- Input:
pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray
- Output:
kernel: FloatArray
- Parameters:
config (ThreeFactorHebbianRuleConfig | None)
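A minimal JAX sketch of one step of a reward-modulated (three-factor) Hebbian update, taking the same inputs as `__call__` (`reward`, `pre_spikes`, `post_spikes`, `kernel`). It is not the `spark` implementation. It uses plain `jax.Array` inputs instead of `SpikeArray`/`FloatArray` payloads, and the `TraceState` container, the `three_factor_step` function, the default values, `dt`, and the `(n_pre, n_post)` kernel layout are all hypothetical. The sketch keeps one exponentially decaying trace per neuron, forms an STDP-like pre/post term, and lets a scalar reward gate that term.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class TraceState(NamedTuple):
    """Hypothetical per-neuron spike traces carried between steps."""
    pre_trace: jax.Array   # shape (n_pre,)
    post_trace: jax.Array  # shape (n_post,)


def three_factor_step(state, reward, pre_spikes, post_spikes, kernel,
                      pre_tau=20.0, post_tau=20.0, gamma=1e-2, dt=1.0):
    """One step of a reward-modulated, STDP-style three-factor rule (sketch).

    Assumes kernel[i, j] is the weight from presynaptic neuron i to
    postsynaptic neuron j, i.e. kernel.shape == (n_pre, n_post).
    """
    # Decay each trace by exp(-dt / tau), then add the current spikes.
    pre_trace = state.pre_trace * jnp.exp(-dt / pre_tau) + pre_spikes
    post_trace = state.post_trace * jnp.exp(-dt / post_tau) + post_spikes
    # Pre-before-post pairings potentiate; post-before-pre pairings depress.
    hebbian = jnp.outer(pre_trace, post_spikes) - jnp.outer(pre_spikes, post_trace)
    # The third factor (a scalar reward) gates and scales the Hebbian term.
    new_kernel = kernel + gamma * reward * hebbian
    return TraceState(pre_trace, post_trace), new_kernel


# Usage: presynaptic neuron 0 fires, then postsynaptic neuron 1 fires
# one step later while a positive reward arrives.
n_pre, n_post = 3, 2
state = TraceState(jnp.zeros(n_pre), jnp.zeros(n_post))
kernel = jnp.zeros((n_pre, n_post))
state, kernel = three_factor_step(
    state, 0.0, jnp.array([1.0, 0.0, 0.0]), jnp.zeros(n_post), kernel)
state, kernel = three_factor_step(
    state, 1.0, jnp.zeros(n_pre), jnp.array([0.0, 1.0]), kernel)
# Only kernel[0, 1] (pre 0 -> post 1) is potentiated, by about
# gamma * exp(-dt / pre_tau).
print(kernel)
```

Carrying the traces in an explicit `NamedTuple` keeps the step a pure function, so it can be wrapped in `jax.jit` or driven by `jax.lax.scan`; how `ThreeFactorHebbianRule` stores its own state is not documented on this page.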
- build(input_specs)
Builds the rule from the specifications of its input ports.
- Parameters:
input_specs (dict[str, spark.core.specs.PortSpecs])
- __call__(reward, pre_spikes, post_spikes, kernel)
Computes and returns the next kernel update.
- Parameters:
reward (spark.core.payloads.FloatArray)
pre_spikes (spark.core.payloads.SpikeArray)
post_spikes (spark.core.payloads.SpikeArray)
kernel (spark.core.payloads.FloatArray)
- Return type: