spark.nn.components.learning_rules.hebbian_rule
Classes
| Class | Description |
|---|---|
| HebbianRuleConfig | HebbianRule configuration class. |
| HebbianRule | Hebbian plasticity rule model. |
| OjaRuleConfig | OjaRule configuration class. |
| OjaRule | Oja's plasticity rule model. |
Module Contents
- class spark.nn.components.learning_rules.hebbian_rule.HebbianRuleConfig(__skip_validation__=False, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRuleConfig
HebbianRule configuration class.
- Parameters:
__skip_validation__ (bool)
- post_tau: float | jax.Array | spark.nn.initializers.base.Initializer
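The post_tau attribute accepts a scalar, an array, or an Initializer, which suggests it sets a (possibly per-neuron) time constant for a postsynaptic activity trace. The sketch below shows one common way such a time constant is used, assuming an exponentially decaying trace updated once per step of size dt. The helper update_trace and the parameter dt are hypothetical and not part of the Spark API; Spark's internal trace update may differ.

```python
import math


def update_trace(trace, spikes, tau, dt=1.0):
    """Hypothetical helper: decay each trace by exp(-dt / tau), then add spikes.

    tau may be a single float shared by all neurons, or a list with one time
    constant per neuron, mirroring the float | array typing of post_tau.
    """
    taus = tau if isinstance(tau, (list, tuple)) else [tau] * len(trace)
    return [math.exp(-dt / t) * x + s for x, s, t in zip(trace, spikes, taus)]


# Two postsynaptic neurons that both spike at step 0 and then stay silent.
spike_train = [[1, 1], [0, 0], [0, 0]]
shared = [0.0, 0.0]
per_neuron = [0.0, 0.0]
for spikes in spike_train:
    shared = update_trace(shared, spikes, tau=10.0)
    per_neuron = update_trace(per_neuron, spikes, tau=[10.0, 2.0])
```

A shorter time constant makes a neuron's trace forget faster: after two silent steps the second per-neuron trace (tau = 2) has decayed to exp(-1), while a trace with tau = 10 is still at exp(-0.2).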
- class spark.nn.components.learning_rules.hebbian_rule.HebbianRule(config=None, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRule
Hebbian plasticity rule model.
- Init:
pre_tau: float | jax.Array
post_tau: float | jax.Array
gamma: float | jax.Array
- Input:
pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray
- Output:
kernel: FloatArray
- Parameters:
config (HebbianRuleConfig | None)
- config: HebbianRuleConfig
- build(input_specs)
Build method. Sets up the rule from input_specs, which maps each input port name to its PortSpecs.
- Parameters:
input_specs (dict[str, spark.core.specs.PortSpecs])
- __call__(pre_spikes, post_spikes, kernel)
Computes and returns the next kernel update.
- Parameters:
pre_spikes (spark.core.payloads.SpikeArray)
post_spikes (spark.core.payloads.SpikeArray)
kernel (spark.core.payloads.FloatArray)
- Return type:
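__call__ takes presynaptic spikes, postsynaptic spikes, and the current kernel, and produces the next kernel. The exact update HebbianRule applies is not spelled out here, so the following is only a sketch of a plain trace-based Hebbian rule, dW[i][j] = gamma * post_trace[i] * pre_trace[j], where the traces decay with time constants pre_tau and post_tau. The kernel layout (rows indexed by postsynaptic neuron, columns by presynaptic neuron), the explicit trace state, and the helpers decay_and_add and hebbian_step are all assumptions made for illustration.

```python
import math


def decay_and_add(trace, spikes, tau, dt=1.0):
    """Exponential trace: decay by exp(-dt / tau), then add this step's spikes."""
    d = math.exp(-dt / tau)
    return [d * x + s for x, s in zip(trace, spikes)]


def hebbian_step(pre_spikes, post_spikes, kernel, traces, pre_tau, post_tau, gamma, dt=1.0):
    """One step of a hypothetical trace-based Hebbian rule (not Spark's code).

    kernel[i][j] is assumed to be the weight from presynaptic neuron j to
    postsynaptic neuron i. Returns the updated kernel and the new traces.
    """
    pre_trace = decay_and_add(traces["pre"], pre_spikes, pre_tau, dt)
    post_trace = decay_and_add(traces["post"], post_spikes, post_tau, dt)
    # Hebbian term: a weight grows in proportion to coincident pre and post activity.
    new_kernel = [
        [w + gamma * post_trace[i] * pre_trace[j] for j, w in enumerate(row)]
        for i, row in enumerate(kernel)
    ]
    return new_kernel, {"pre": pre_trace, "post": post_trace}


# Two presynaptic neurons feeding one postsynaptic neuron.
kernel = [[0.5, 0.5]]
traces = {"pre": [0.0, 0.0], "post": [0.0]}
# Presynaptic neuron 0 fires together with the postsynaptic neuron; neuron 1 stays silent.
kernel, traces = hebbian_step(
    [1, 0], [1], kernel, traces, pre_tau=20.0, post_tau=20.0, gamma=0.01
)
```

Only the synapse whose presynaptic neuron fired alongside the postsynaptic neuron is strengthened. In this sketch, with non-negative traces and a positive gamma, weights can only grow, so the plain Hebbian term is unbounded on its own; Oja's rule, documented below, adds a decay term that keeps weights bounded.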
- class spark.nn.components.learning_rules.hebbian_rule.OjaRuleConfig(__skip_validation__=False, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRuleConfig
OjaRule configuration class.
- Parameters:
__skip_validation__ (bool)
- post_tau: float | jax.Array | spark.nn.initializers.base.Initializer
- class spark.nn.components.learning_rules.hebbian_rule.OjaRule(config=None, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRule
Oja’s plasticity rule model.
- Init:
pre_tau: float | jax.Array
post_tau: float | jax.Array
gamma: float | jax.Array
- Input:
pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray
- Output:
kernel: FloatArray
- Parameters:
config (OjaRuleConfig | None)
- config: OjaRuleConfig
- build(input_specs)
Build method. Sets up the rule from input_specs, which maps each input port name to its PortSpecs.
- Parameters:
input_specs (dict[str, spark.core.specs.PortSpecs])
- __call__(pre_spikes, post_spikes, kernel)
Computes and returns the next kernel update.
- Parameters:
pre_spikes (spark.core.payloads.SpikeArray)
post_spikes (spark.core.payloads.SpikeArray)
kernel (spark.core.payloads.FloatArray)
- Return type:
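OjaRule has the same Init parameters and inputs as HebbianRule. Oja's rule adds a normalizing term to the Hebbian product: for presynaptic activity x and postsynaptic activity y, dW[i][j] = gamma * y[i] * (x[j] - y[i] * W[i][j]). The sketch below is the classical rate-based form with a linear response y = W x, not Spark's spiking implementation. Whether OjaRule uses spike traces as x and y, and its kernel layout, are assumptions here, and oja_step is a hypothetical helper.

```python
import math


def oja_step(x, kernel, gamma):
    """One step of the classical rate-based Oja rule for a linear layer.

    x: presynaptic activity, one value per presynaptic neuron.
    kernel[i][j]: weight from presynaptic j to postsynaptic i (assumed layout).
    Returns the updated kernel.
    """
    # Linear postsynaptic response y_i = sum_j W_ij * x_j.
    y = [sum(w * xj for w, xj in zip(row, x)) for row in kernel]
    # Hebbian term y_i * x_j minus the normalizing term y_i^2 * W_ij.
    return [
        [w + gamma * y[i] * (x[j] - y[i] * w) for j, w in enumerate(row)]
        for i, row in enumerate(kernel)
    ]


# Alternate two input patterns; their correlation matrix is diag(2.0, 0.5),
# so the principal direction of the inputs is the first axis.
patterns = [[2.0, 0.0], [0.0, 1.0]]
kernel = [[0.3, 0.4]]
for step in range(10_000):
    kernel = oja_step(patterns[step % 2], kernel, gamma=0.01)

norm = math.sqrt(sum(w * w for w in kernel[0]))
```

With a small gamma the weight vector settles at unit length along the principal direction of the inputs (here close to [1.0, 0.0]). This built-in normalization is what distinguishes Oja's rule from the plain Hebbian rule above, whose weights grow without bound.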