spark.nn.components.learning_rules.quadruplet_rule
Classes
- QuadrupletRuleConfig: QuadrupletRule configuration class.
- QuadrupletRule: Quadruplet plasticity rule model.
- QuadrupletRuleTensorConfig: QuadrupletRuleTensor configuration class.
- QuadrupletRuleTensor: Quadruplet plasticity rule model (tensor).
Module Contents
- class spark.nn.components.learning_rules.quadruplet_rule.QuadrupletRuleConfig(__skip_validation__=False, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRuleConfig
QuadrupletRule configuration class.
- Parameters:
__skip_validation__ (bool)
- class spark.nn.components.learning_rules.quadruplet_rule.QuadrupletRule(config=None, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRule
Quadruplet plasticity rule model.
- Init:
pre_tau: float | jax.Array
post_tau: float | jax.Array
q_alpha: float | jax.Array
q_beta: float | jax.Array
q_gamma: float | jax.Array
q_delta: float | jax.Array
eta: float | jax.Array
- Input:
modulation: FloatArray
pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray
- Output:
kernel: FloatArray
- Parameters:
config (QuadrupletRuleConfig | None)
- config: QuadrupletRuleConfig
- build(input_specs)
Build method. Receives input_specs, which maps each input port name to its PortSpecs.
- Parameters:
input_specs (dict[str, spark.core.specs.PortSpecs])
- __call__(modulation, pre_spikes, post_spikes, kernel)
Computes and returns the next kernel update.
- Parameters:
modulation (spark.core.payloads.FloatArray)
pre_spikes (spark.core.payloads.SpikeArray)
post_spikes (spark.core.payloads.SpikeArray)
kernel (spark.core.payloads.FloatArray)
- Return type:
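The entry above lists the parameters but does not state the update equation. As a rough illustration only, the sketch below assumes one plausible reading of the Init parameters: pre_tau and post_tau are time constants of exponentially decaying spike traces, q_alpha, q_beta, q_gamma and q_delta weight four pairwise pre/post terms, eta is a learning rate, and modulation multiplies the whole change. The helper names (decay_trace, quadruplet_step), the mapping of coefficients to terms, the default values, the time step dt and the kernel layout (kernel[i, j] links presynaptic neuron i to postsynaptic neuron j) are all hypothetical; this is not the library's implementation. NumPy is used so the sketch runs standalone; jax.numpy provides the same calls.

```python
import numpy as np


def decay_trace(trace, spikes, tau, dt=1.0):
    """Exponential spike trace: decay by exp(-dt / tau), then add this step's spikes."""
    return trace * np.exp(-dt / tau) + spikes


def quadruplet_step(kernel, pre_spikes, post_spikes, pre_trace, post_trace,
                    modulation=1.0, *, pre_tau=20.0, post_tau=20.0,
                    q_alpha=1.0, q_beta=-1.0, q_gamma=0.0, q_delta=0.0,
                    eta=1e-3, dt=1.0):
    """Hypothetical single plasticity step; kernel has shape (n_pre, n_post)."""
    # ASSUMED mapping of the four coefficients to pairwise pre/post terms.
    pre_then_post = np.outer(pre_trace, post_spikes)   # earlier pre activity, post spike now
    post_then_pre = np.outer(pre_spikes, post_trace)   # pre spike now, earlier post activity
    trace_pair = np.outer(pre_trace, post_trace)       # both traces, no spike needed
    same_step = np.outer(pre_spikes, post_spikes)      # pre and post spike in the same step
    dw = (q_alpha * pre_then_post + q_beta * post_then_pre
          + q_gamma * trace_pair + q_delta * same_step)
    # Modulation gates the whole change multiplicatively (assumption).
    new_kernel = kernel + eta * modulation * dw
    # Traces are updated after the kernel, so the first two terms only see earlier spikes.
    new_pre_trace = decay_trace(pre_trace, pre_spikes, pre_tau, dt)
    new_post_trace = decay_trace(post_trace, post_spikes, post_tau, dt)
    return new_kernel, new_pre_trace, new_post_trace


# Usage: presynaptic neuron 0 fires one step before postsynaptic neuron 1.
kernel = np.zeros((2, 3))
pre_trace, post_trace = np.zeros(2), np.zeros(3)
kernel, pre_trace, post_trace = quadruplet_step(
    kernel, np.array([1.0, 0.0]), np.zeros(3), pre_trace, post_trace)
kernel, pre_trace, post_trace = quadruplet_step(
    kernel, np.zeros(2), np.array([0.0, 1.0, 0.0]), pre_trace, post_trace)
```

With pre firing one step before post, only the q_alpha term contributes and kernel[0, 1] grows by eta; reversing the order makes the q_beta term contribute instead. Setting modulation to 0 leaves the kernel unchanged, which is how a reward or neuromodulatory signal would gate learning under this assumed form.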
- class spark.nn.components.learning_rules.quadruplet_rule.QuadrupletRuleTensorConfig(__skip_validation__=False, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRuleConfig
QuadrupletRuleTensor configuration class.
- Parameters:
__skip_validation__ (bool)
- class spark.nn.components.learning_rules.quadruplet_rule.QuadrupletRuleTensor(config=None, **kwargs)
Bases: spark.nn.components.learning_rules.base.LearningRule
Quadruplet plasticity rule model (tensor).
- Init:
pre_tau: float | jax.Array
post_tau: float | jax.Array
eta: float | jax.Array
- Input:
modulation: FloatArray
pre_spikes: SpikeArray
post_spikes: SpikeArray
kernel: FloatArray
- Output:
kernel: FloatArray
- Parameters:
config (QuadrupletRuleTensorConfig | None)
- build(input_specs)
Build method. Receives input_specs, which maps each input port name to its PortSpecs.
- Parameters:
input_specs (dict[str, spark.core.specs.PortSpecs])
- __call__(modulation, pre_spikes, post_spikes, kernel)
Computes and returns the next kernel update.
- Parameters:
modulation (spark.core.payloads.FloatArray)
pre_spikes (spark.core.payloads.SpikeArray)
post_spikes (spark.core.payloads.SpikeArray)
kernel (spark.core.payloads.FloatArray)
- Return type:
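QuadrupletRuleTensor exposes only pre_tau, post_tau and eta in its Init list. If pre_tau and post_tau are read as time constants of exponentially decaying spike traces (an assumption; the docstring does not say so), they set how long a spike keeps influencing the update. The hypothetical helpers below (run_trace, trace_after_spike, with an assumed fixed step dt) simulate such a trace and compare it with the closed form exp(-n * dt / tau). They illustrate the time constant only and are not the library's code.

```python
import math


def run_trace(spike_train, tau, dt=1.0):
    """Simulate x <- x * exp(-dt / tau) + s for a sequence of 0/1 spikes; return every x."""
    decay = math.exp(-dt / tau)
    x, history = 0.0, []
    for s in spike_train:
        x = x * decay + s
        history.append(x)
    return history


def trace_after_spike(steps, tau, dt=1.0):
    """Closed form: value of the trace `steps` time steps after one unit spike."""
    return math.exp(-steps * dt / tau)


# Usage: one spike at step 0, then silence, for a short and a long time constant.
fast = run_trace([1, 0, 0, 0], tau=5.0)
slow = run_trace([1, 0, 0, 0], tau=20.0)
```

After tau / dt steps the trace left by a single spike has fallen to 1/e of its initial value, so a larger pre_tau or post_tau widens the window over which pre- and postsynaptic spikes can still interact.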