spark.nn.components.learning_rules
Submodules
Classes
LearningRule | Abstract learning rule model.
LearningRuleConfig | Abstract learning rule configuration class.
LearningRuleOutput | Generic learning rule model output spec.
ZenkeRule | Zenke plasticity rule model, an extension of the classic Hebbian rule.
ZenkeRuleConfig | ZenkeRule configuration class.
HebbianRule | Hebbian plasticity rule model.
HebbianRuleConfig | HebbianRule configuration class.
OjaRule | Oja's plasticity rule model.
OjaRuleConfig | OjaRule configuration class.
QuadrupletRule | Quadruplet plasticity rule model.
QuadrupletRuleConfig | QuadrupletRule configuration class.
QuadrupletRuleTensor | Quadruplet plasticity rule model (tensor).
QuadrupletRuleTensorConfig | QuadrupletRuleTensor configuration class.
ThreeFactorHebbianRule | Three-factor Hebbian plasticity rule model.
ThreeFactorHebbianRuleConfig | ThreeFactorHebbianRule configuration class.
Package Contents
- class spark.nn.components.learning_rules.LearningRule(config=None, **kwargs)
Bases: spark.nn.components.base.Component, Generic[ConfigT]
Abstract learning rule model.
- Parameters:
config (ConfigT | None)
- class spark.nn.components.learning_rules.LearningRuleConfig(__skip_validation__=False, **kwargs)
Bases: spark.nn.components.base.ComponentConfig
Abstract learning rule configuration class.
- Parameters:
__skip_validation__ (bool)
- class spark.nn.components.learning_rules.LearningRuleOutput
Bases: TypedDict
Generic learning rule model output spec.
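Every concrete rule below consumes presynaptic spikes, postsynaptic spikes and the current kernel, and each Output section lists a single `kernel` entry. The sketch below mirrors that contract in plain Python so the data flow is easy to follow. It is not Spark code: `ToyOutput` and `toy_rule` are hypothetical, NumPy arrays stand in for `SpikeArray`/`FloatArray`, the `(n_pre, n_post)` kernel layout is assumed, and the sketch assumes the output dictionary is keyed by `"kernel"` and holds the updated kernel (this page does not say whether Spark returns the new kernel or only the increment).

```python
from typing import TypedDict

import numpy as np


class ToyOutput(TypedDict):
    # Hypothetical stand-in for LearningRuleOutput; assumes a single "kernel" entry.
    kernel: np.ndarray


def toy_rule(pre_spikes: np.ndarray, post_spikes: np.ndarray, kernel: np.ndarray) -> ToyOutput:
    """Hypothetical rule: strengthen synapses whose pre and post neurons spiked together.

    Assumes kernel has shape (n_pre, n_post) and spikes are 0/1 vectors.
    """
    coincidence = np.outer(pre_spikes, post_spikes)  # 1 where both sides spiked this step
    return {"kernel": kernel + 0.01 * coincidence}


pre = np.array([1.0, 0.0, 1.0])
post = np.array([0.0, 1.0])
out = toy_rule(pre, post, np.zeros((3, 2)))
print(out["kernel"])
```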
- class spark.nn.components.learning_rules.ZenkeRule(config=None, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRule
Zenke plasticity rule model, an extension of the classic Hebbian rule.
- Init:
pre_tau: float | jax.Array
post_tau: float | jax.Array
post_slow_tau: float | jax.Array
target_tau: float | jax.Array
a: float | jax.Array
b: float | jax.Array
c: float | jax.Array
d: float | jax.Array
P: float | jax.Array
eta: float | jax.Array
- Input:
pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray
- Output:
kernel: FloatArray
- Parameters:
config (ZenkeRuleConfig | None)
- config: ZenkeRuleConfig
- build(input_specs)
Build method.
- Parameters:
input_specs (dict[str, spark.core.specs.PortSpecs])
- __call__(pre_spikes, post_spikes, kernel)
Computes and returns the next kernel update.
- Parameters:
pre_spikes (spark.core.payloads.SpikeArray)
post_spikes (spark.core.payloads.SpikeArray)
kernel (spark.core.payloads.FloatArray)
- Return type:
- class spark.nn.components.learning_rules.ZenkeRuleConfig(__skip_validation__=False, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRuleConfig
ZenkeRule configuration class.
- Parameters:
__skip_validation__ (bool)
- post_tau: float | jax.Array | spark.nn.initializers.base.Initializer
- post_slow_tau: float | jax.Array | spark.nn.initializers.base.Initializer
- target_tau: float | jax.Array | spark.nn.initializers.base.Initializer
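ZenkeRule takes both a `post_tau` and a `post_slow_tau`. Rules in the Zenke family typically combine a fast and a slow postsynaptic trace, each a low-pass filter of the spike train. The sketch below shows such a pair of exponential traces. It is an illustration under assumptions, not Spark's implementation: the update form `x <- x * exp(-dt / tau) + s`, the time step `dt`, the helper `run_trace`, and the time-constant values are all hypothetical.

```python
import numpy as np


def run_trace(spikes: np.ndarray, tau: float, dt: float = 1.0) -> np.ndarray:
    """Exponentially decaying trace that jumps by 1 on every spike (assumed form)."""
    decay = np.exp(-dt / tau)
    trace = np.zeros(len(spikes))
    x = 0.0
    for t, s in enumerate(spikes):
        x = x * decay + s  # decay the previous value, then add this step's spike
        trace[t] = x
    return trace


post_spikes = np.zeros(200)
post_spikes[[10, 15, 20]] = 1.0  # a short burst of three spikes

fast = run_trace(post_spikes, tau=20.0)   # stands in for post_tau (assumed value)
slow = run_trace(post_spikes, tau=100.0)  # stands in for post_slow_tau (assumed value)
print(fast[150], slow[150])
```

Long after the burst, the slow trace still remembers it while the fast trace has almost returned to zero, which is what lets a rule distinguish recent spikes from the longer-term firing history.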
- class spark.nn.components.learning_rules.HebbianRule(config=None, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRule
Hebbian plasticity rule model.
- Init:
pre_tau: float | jax.Array
post_tau: float | jax.Array
gamma: float | jax.Array
- Input:
pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray
- Output:
kernel: FloatArray
- Parameters:
config (HebbianRuleConfig | None)
- config: HebbianRuleConfig
- build(input_specs)
Build method.
- Parameters:
input_specs (dict[str, spark.core.specs.PortSpecs])
- __call__(pre_spikes, post_spikes, kernel)
Computes and returns the next kernel update.
- Parameters:
pre_spikes (spark.core.payloads.SpikeArray)
post_spikes (spark.core.payloads.SpikeArray)
kernel (spark.core.payloads.FloatArray)
- Return type:
- class spark.nn.components.learning_rules.HebbianRuleConfig(__skip_validation__=False, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRuleConfig
HebbianRule configuration class.
- Parameters:
__skip_validation__ (bool)
- post_tau: float | jax.Array | spark.nn.initializers.base.Initializer
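The `pre_tau` and `post_tau` parameters suggest that HebbianRule correlates exponentially filtered pre- and postsynaptic activity. The sketch below is one common trace-based Hebbian formulation, written with NumPy for illustration only. Assumptions: the `(n_pre, n_post)` kernel layout, the trace update form, the hypothetical helper `hebbian_step`, and its learning rate `eta` and step `dt`. The role of `gamma` in Spark's HebbianRule is not documented on this page, so the sketch does not model it.

```python
import numpy as np


def hebbian_step(pre_spikes, post_spikes, kernel, pre_trace, post_trace,
                 pre_tau=20.0, post_tau=20.0, eta=1e-3, dt=1.0):
    """One step of a trace-based Hebbian update (illustrative sketch).

    Assumes kernel has shape (n_pre, n_post); eta and dt are hypothetical parameters.
    """
    pre_trace = pre_trace * np.exp(-dt / pre_tau) + pre_spikes
    post_trace = post_trace * np.exp(-dt / post_tau) + post_spikes
    # Outer product of the traces: grows only where both sides were recently active.
    kernel = kernel + eta * np.outer(pre_trace, post_trace)
    return kernel, pre_trace, post_trace


rng = np.random.default_rng(0)
n_pre, n_post = 4, 3
kernel = np.zeros((n_pre, n_post))
pre_tr, post_tr = np.zeros(n_pre), np.zeros(n_post)
for _ in range(100):
    # Last two presynaptic neurons never fire (firing probability 0).
    pre = (rng.random(n_pre) < np.array([0.3, 0.3, 0.0, 0.0])).astype(float)
    post = (rng.random(n_post) < 0.3).astype(float)
    kernel, pre_tr, post_tr = hebbian_step(pre, post, kernel, pre_tr, post_tr)
```

Synapses from silent presynaptic neurons stay at zero, while all others only grow. Pure Hebbian growth of this kind is unbounded, which is the usual motivation for normalizing variants such as Oja's rule.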
- class spark.nn.components.learning_rules.OjaRule(config=None, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRule
Oja's plasticity rule model.
- Init:
pre_tau: float | jax.Array
post_tau: float | jax.Array
gamma: float | jax.Array
- Input:
pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray
- Output:
kernel: FloatArray
- Parameters:
config (OjaRuleConfig | None)
- config: OjaRuleConfig
- build(input_specs)
Build method.
- Parameters:
input_specs (dict[str, spark.core.specs.PortSpecs])
- __call__(pre_spikes, post_spikes, kernel)
Computes and returns the next kernel update.
- Parameters:
pre_spikes (spark.core.payloads.SpikeArray)
post_spikes (spark.core.payloads.SpikeArray)
kernel (spark.core.payloads.FloatArray)
- Return type:
- class spark.nn.components.learning_rules.OjaRuleConfig(__skip_validation__=False, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRuleConfig
OjaRule configuration class.
- Parameters:
__skip_validation__ (bool)
- post_tau: float | jax.Array | spark.nn.initializers.base.Initializer
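In its classic rate-based form, Oja's rule adds a decay term to the Hebbian product, `dw = eta * y * (x - y * w)` with `y = w . x`, which keeps the weight norm bounded; for zero-mean inputs the weight vector converges to the unit-norm principal eigenvector of the input covariance. The sketch below shows that classic form with NumPy. How Spark's OjaRule maps spikes onto `x` and `y` (presumably via the `pre_tau`/`post_tau` traces) and what `gamma` controls are not stated here, so `oja_step`, `eta`, and the input statistics are illustrative assumptions.

```python
import numpy as np


def oja_step(x, w, eta):
    """Classic single-output Oja update: dw = eta * y * (x - y * w), with y = w . x."""
    y = w @ x
    return w + eta * y * (x - y * w)


rng = np.random.default_rng(0)
# Zero-mean inputs whose variance (3 vs 1) is largest along the first axis.
std = np.array([np.sqrt(3.0), 1.0])
w = rng.normal(scale=0.1, size=2)
for _ in range(5000):
    x = rng.normal(size=2) * std
    w = oja_step(x, w, eta=0.005)
print(w, np.linalg.norm(w))
```

Unlike the plain Hebbian update, the `- y * w` term stops the weights from growing without bound, and the weight vector settles near plus or minus the first axis, the direction of largest input variance.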
- class spark.nn.components.learning_rules.QuadrupletRule(config=None, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRule
Quadruplet plasticity rule model.
- Init:
pre_tau: float | jax.Array
post_tau: float | jax.Array
q_alpha: float | jax.Array
q_beta: float | jax.Array
q_gamma: float | jax.Array
q_delta: float | jax.Array
eta: float | jax.Array
- Input:
modulation: FloatArray
pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray
- Output:
kernel: FloatArray
- Parameters:
config (QuadrupletRuleConfig | None)
- config: QuadrupletRuleConfig
- build(input_specs)
Build method.
- Parameters:
input_specs (dict[str, spark.core.specs.PortSpecs])
- __call__(modulation, pre_spikes, post_spikes, kernel)
Computes and returns the next kernel update.
- Parameters:
modulation (spark.core.payloads.FloatArray)
pre_spikes (spark.core.payloads.SpikeArray)
post_spikes (spark.core.payloads.SpikeArray)
kernel (spark.core.payloads.FloatArray)
- Return type:
- class spark.nn.components.learning_rules.QuadrupletRuleConfig(__skip_validation__=False, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRuleConfig
QuadrupletRule configuration class.
- Parameters:
__skip_validation__ (bool)
- class spark.nn.components.learning_rules.QuadrupletRuleTensor(config=None, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRule
Quadruplet plasticity rule model (tensor).
- Init:
pre_tau: float | jax.Array
post_tau: float | jax.Array
eta: float | jax.Array
- Input:
modulation: FloatArray
pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray
- Output:
kernel: FloatArray
- Parameters:
config (QuadrupletRuleTensorConfig | None)
- build(input_specs)
Build method.
- Parameters:
input_specs (dict[str, spark.core.specs.PortSpecs])
- __call__(modulation, pre_spikes, post_spikes, kernel)
Computes and returns the next kernel update.
- Parameters:
modulation (spark.core.payloads.FloatArray)
pre_spikes (spark.core.payloads.SpikeArray)
post_spikes (spark.core.payloads.SpikeArray)
kernel (spark.core.payloads.FloatArray)
- Return type:
- class spark.nn.components.learning_rules.QuadrupletRuleTensorConfig(__skip_validation__=False, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRuleConfig
QuadrupletRuleTensor configuration class.
- Parameters:
__skip_validation__ (bool)
- class spark.nn.components.learning_rules.ThreeFactorHebbianRule(config=None, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRule
Three-factor Hebbian plasticity rule model.
- Init:
pre_tau: float | jax.Array
post_tau: float | jax.Array
gamma: float | jax.Array
- Input:
reward: FloatArray
pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray
- Output:
kernel: FloatArray
- Parameters:
config (ThreeFactorHebbianRuleConfig | None)
- build(input_specs)
Build method.
- Parameters:
input_specs (dict[str, spark.core.specs.PortSpecs])
- __call__(reward, pre_spikes, post_spikes, kernel)
Computes and returns the next kernel update.
- Parameters:
reward (spark.core.payloads.FloatArray)
pre_spikes (spark.core.payloads.SpikeArray)
post_spikes (spark.core.payloads.SpikeArray)
kernel (spark.core.payloads.FloatArray)
- Return type:
- class spark.nn.components.learning_rules.ThreeFactorHebbianRuleConfig(__skip_validation__=False, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRuleConfig
ThreeFactorHebbianRule configuration class.
- Parameters:
__skip_validation__ (bool)
- post_tau: float | jax.Array | spark.nn.initializers.base.Initializer
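ThreeFactorHebbianRule adds a `reward` input to the usual pre/post pair. In a three-factor rule, the Hebbian correlation acts as an eligibility signal and a third, neuromodulatory factor decides whether and in which direction it is written into the kernel. The sketch below illustrates that gating with NumPy. It is not Spark's implementation: it assumes a scalar reward (Spark's `reward` is a FloatArray of unspecified shape), precomputed pre/post traces, a `(n_pre, n_post)` kernel, and a hypothetical helper `three_factor_step` with learning rate `eta`. Many three-factor rules also low-pass filter the eligibility over time; this sketch omits that.

```python
import numpy as np


def three_factor_step(reward, pre_trace, post_trace, kernel, eta=1e-2):
    """Reward-gated Hebbian update (illustrative sketch).

    The Hebbian term np.outer(pre_trace, post_trace) acts as an eligibility
    signal; the scalar reward (the third factor) sets the sign and size of the change.
    """
    eligibility = np.outer(pre_trace, post_trace)
    return kernel + eta * reward * eligibility


pre_trace = np.array([1.0, 0.5, 0.0])
post_trace = np.array([0.2, 1.0])
kernel = np.zeros((3, 2))

rewarded = three_factor_step(1.0, pre_trace, post_trace, kernel)
punished = three_factor_step(-1.0, pre_trace, post_trace, kernel)
neutral = three_factor_step(0.0, pre_trace, post_trace, kernel)
```

With zero reward nothing changes even though pre and post activity are correlated; positive and negative rewards move the same synapses in opposite directions.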